import torchviz
# 定义模型
print('定义模型')
class ConvLSTM(nn.Module):
def __init__(self, num_classes):
super(ConvLSTM, self).__init__()
# 省略模型的定义代码
# 创建模型实例
net = ConvLSTM(num_classes=num_classes).to(device)
# 创建一个随机输入
example_input = torch.randn(1, 1, X_train_pca.shape[1]).to(device)
# 使用torchviz可视化模型计算图
graph = torchviz.make_dot(net(example_input), params=dict(net.named_parameters()))
# 保存计算图为图片
graph.render("model_graph", format="png")
【无标题】
于 2023-10-26 08:57:21 首次发布