Créer un Dataset pour PyTorch

PyTorch fournit des librairies puissantes pour faire du deep learning avec Python.

Pour des exercices pratiques, on peut utiliser des datasets déjà pré-définis.

Néanmoins, il arrive que l’on veuille définir son propre dataset à utiliser avec PyTorch, notamment lorsque l’on travaille avec des données que l’on a collectées soi-même.

Il faut alors voir comment définir son propre dataset PyTorch, de sorte à exploiter le plein potentiel de cette librairie, notamment avec les dataloaders.

🚩 Problème :

Comment définir un dataset pour PyTorch?

Solution :

Utiliser l’objet Dataset de torch.utils.data.

Vous devez définir une classe avec trois méthodes:

  • __init__ pour initialiser votre dataset
  • __len__ pour définir ce qu’est la longueur de votre dataset
  • __getitem__ pour accéder aux éléments du dataset

Et votre classe doit hériter de Dataset.

Voir l’exemple ci-dessous pour avoir la structure générique.

🤠 Exemple :

from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, X, Y):
self.X = X
self.Y = Y
def __len__(self):
return len(self.Y)
def __getitem__(self, idx):
x = self.X[idx]
y = self.Y[idx]
return x, y
dataset = CustomDataset(X, Y) # où X et Y sont par exemple des numpy arrays
# qui contiennent respectivement les features et targets de votre jeu de données.

Laisser un commentaire