import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter
class Cxy(nn.Module):
    """Small convolutional classifier assembled as one ``Sequential`` stack.

    The stack is three ``Conv2d`` (kernel 5, padding 2) + ``MaxPool2d(2)``
    stages, then ``Flatten`` and two ``Linear`` layers ending in 10 outputs.

    ``Linear(1024, 64)`` fixes the flattened feature size at 1024. For a
    3x32x32 input the convolutions keep the spatial size (kernel 5 with
    padding 2) and each pooling halves it (32 -> 16 -> 8 -> 4), which gives
    64 channels * 4 * 4 = 1024 features.
    """

    def __init__(self):
        super().__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.model1(x)
# Build the network and print its layer structure.
cxy = Cxy()
print(cxy)

# All-ones dummy batch (64 samples of 3x32x32) used to check the output
# shape and to trace the graph. Named ``dummy_input`` so the builtin
# ``input`` is not shadowed.
dummy_input = torch.ones((64, 3, 32, 32))
output = cxy(dummy_input)
print(output.shape)

# The context manager closes the writer even if ``add_graph`` raises, so the
# event file under ../logs_seq is always flushed.
with SummaryWriter("../logs_seq") as writer:
    writer.add_graph(cxy, dummy_input)
以上代码搭建了一个简单的卷积网络,并将其计算图写入 TensorBoard 日志目录 logs_seq。
注意:使用 TensorBoard 之前,必须先在终端(terminal)中激活对应的 conda 环境:conda activate pytorch
然后在 logs_seq 所在的目录(代码中写入的是 ../logs_seq,即脚本运行目录的上一级)下输入:tensorboard --logdir=logs_seq
在浏览器中打开终端里显示的地址,就可以看到如下的模型结构图:
(图:模型计算图的可视化结果)