Les modèles de deep learning développés en PyTorch peuvent être très grands et donc difficiles à appréhender.
Il est alors tentant d’afficher un résumé du modèle avec lequel on travaille. Si, en Keras, l’on dispose de la méthode summary qui nous permet de faire model.summary() pour afficher un joli résumé de model, en PyTorch, nous ne disposons pas d’une telle méthode.
Dès lors, il convient de trouver une méthode alternative qui nous affichera une vue d’ensemble des caractéristiques de notre modèle.
🚩 Problème :
Comment afficher un résumé d’un modèle PyTorch ?
✅ Solution :
Une première option est d’utiliser la fonction summary de la librairie torchinfo.
Tout d’abord, via un terminal, installer la librairie torchinfo en faisant :
pip install torchinfo
Puis, si l’on souhaite afficher un résumé d’un modèle stocké dans une variable model, faire :
from torchinfo import summary
summary(model)
Une alternative qui peut être plus informative est d’utiliser la fonction summary de la librairie torchsummary cette fois :
pip install torchsummary
puis,
from torchsummary import summary
summary(model, input_size)
La contrainte, ici, est que l’on doit également spécifier la taille d’une entrée (input_size). En contre-partie, le résumé affichera la taille des variables tout au long du modèle, ainsi qu’une estimation en termes de mémoire requise.
🤠 Exemple :
On peut, par exemple, en utilisant torchvision, charger un modèle VGG-16 et regarder ce que l’on obtient :
from torchvision import models
from torchinfo import summary
model = models.vgg16()
print(summary(model))
Cela nous affiche la chose suivante :

torchinfoOn peut alors voir la liste des couches, leurs nombres de paramètres, le nombre total de paramètres, ainsi que la répartition entre les paramètres entraînables et ceux non-entraînables.
Avec l’alternative utilisant la librairie torchsummary, on pourrait préciser une input_size de (3, 224, 224) (image RGB à trois canaux, carrée de taille 224×224) :
from torchvision import models
from torchsummary import summary
model = models.vgg16()
print(summary(model, (3, 224, 224)))
On obtiendrait alors la chose suivante :

torchsummaryOn a alors, cette fois, l’indication après chaque couche de la taille courante de notre tenseur. Le -1 que l’on voit de manière récurrente symbolise la dimension du batch.
On voit également que le résumé intègre une estimation de l’espace mémoire requis par entrée.

Laisser un commentaire