Passer un Tenseur de 1 Channel à 3 Channels en PyTorch

Problème :

Comment passer d’un tenseur à 1 channel à un tenseur à 3 channels en PyTorch?

Solution :

Certains modèles, notamment pour les images, prennent en entrée des tenseurs à 3 channels (images RGB).

Hors, parfois, on veut les utiliser sur des images à 1 channel (images en nuances de gris).

Pour ce faire, utiliser torch.cat pour concaténer votre entrée trois fois avec elle-même, selon l’axe 1.

Exemple :

import torch
x = torch.cat([x, x, x], dim=1) # où x est votre tenseur initial.

Laisser un commentaire