神经网络python可视化工具

Posted by Yuzhong Wu on August 17, 2024

前言

现在网络结构越来越复杂,虽然可以直接阅读代码来理解模型的前馈过程,但是一行行阅读下来会比较耗时间。如果先通过可视化工具对模型结构有大致的了解,那么再阅读模型代码也会高效很多。

神经网络可视化工具可用于检查神经网络结构设计的合理性,确认网络各个节点的tensor的维度,从而便于修改和优化模型结构。

可视化模型结构的方法

[方法1] 使用netron可视化onnx模型

安装及使用

在terminal中,使用pip安装:
pip install netron

然后创建一个python脚本test_netron.py,输入如下内容:

import netron
import torch.onnx
from torch.autograd import Variable
from torchvision.models import resnet18

model = resnet18()
x = torch.randn(5, 3, 40, 40)
model_onnx_path = "./nnet.onnx"
torch.onnx.export(model, x, model_onnx_path)
netron.start(model_onnx_path)

执行这个test_netron.py脚本,会自动打开电脑的网页浏览器,有一个页面显示网络结构: Netron plot NN structure

此外,如果想在上图的基础上,显示每一层输出的feature map的尺寸,可以尝试运行下面的脚本:

import netron
import torch
import torch.onnx
from torchvision.models import resnet18, vgg16

model = vgg16()
x = torch.randn(5, 3, 224, 224)
model_onnx_path = "./nnet.onnx"
torch.onnx.export(model, x, model_onnx_path, opset_version=20) # If opset_version is too new, the onnx.shape_inference may not work. 

import onnx
from onnx import shape_inference
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(model_onnx_path)), model_onnx_path) # use onnx.shape_inference to make later netron visualization show intermediate layers' feature map size

netron.start(model_onnx_path)

执行脚本后应该可以看到可视化结果额外显示了中间层输出的feature map尺寸(红圈处): Netron plot NN structure

[方法2] 使用torchviz可视化pytorch模型

相比netron是从向前传播的角度来观察模型各层的情况(读取onnx模型),torchviz是从反向传播的角度来观察模型结构(读取pytorch模型)。torchviz相比netron有一个重要的区别是,它可以显示你定义的层的名字(比如你在pytorch模型中定义一个层self.fc_large = Linear(),它会显示fc_large,而netron只会显示Linear)。

在terminal中安装graphviz和torchviz:

# MacOS Example
brew install graphviz
pip install torchviz

尝试运行下面的脚本:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.input = nn.Linear(in_features=8, out_features=16)
        self.hidden_1 = nn.Linear(in_features=16, out_features=16)
        self.output = nn.Linear(in_features=16, out_features=6)
        
    def forward(self, x):
        x = torch.relu(self.input(x))
        x = torch.relu(self.hidden_1(x))
        return self.output(x)

from torchviz import make_dot
model = Net()
X = torch.randn(5, 8)
y = model(X)
vis_graph = make_dot(y.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)

vis_graph.format = 'png'
vis_graph.render('model_arch') # save visualization result 

执行完脚本后,应该会生成一个叫model_arch.png的图片,如下图: torchviz model visualization

如果不想显示那么多内容,可以选择不在make_dot中设置show_attrs和show_saved:

vis_graph = make_dot(y.mean(), params=dict(model.named_parameters()))