Créer un Modèle en PyTorch

Problème :

Comment créer un modèle en PyTorch?

Solution :

Utiliser l’objet Module de torch.nn.

Vous devez définir une classe qui hérite de Module et possède les deux méthodes suivantes:

  • __init__ pour initialiser votre objet
  • forward pour appliquer le modèle à des entrées

Voir ci-dessous un exemple générique.

Exemple :

import torch.nn as nn
class CustomModel(nn.Module):
def __init__(self):
super(CustomModel, self).__init__()
# c'est ici que vous pouvez définir les couches du réseau
# que vous utilsierez ensuite dans le forward.
def forward(self, x):
# Ici vous appliquez les couches de votre modèle à l'entrée x.
# Dans cet exemple, je n'ai défini aucune couche, donc le modèle renvoie tout simplement l'entrée.
return x

Laisser un commentaire