测试网络流程是否顺畅
if __name__ == '__main__':
x=torch.randn(2,3,256,256)
net=Union_Seg_1_v1()
print(net(x).shape)
其中,torch.randn(batch_size , channel , size[0] , size[1] )
batch_size : 运行一次输入的数据量个数
channel : 输入的通道数
size : 输入图像的规模(长和宽)
首先,定义输入数据格式
然后,定义网络
将数据输入网络,并打印输出数据的格式
打印网络结构
from torchsummary import summary
# 需要使用device来指定网络在GPU还是CPU运行
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
net=DPNet_v1()
model = net.to(device)
# input_size=(channel,size,size)
summary(model, input_size=(3,256,256))
需要使用torchsummary包。pip install torchsummary 或者 conda install -c ravelbio torchsummary