最近在测试网络模型,记录一下创建model.py后如何进行简单的检测和查看
BUG:
Expected more than 1 value per channel when training, got input size torch.Size([1, 64, 1, 1])
分析:
如标题 此问题是因为进行简单测试时,自己设置了BN层,进行测试时input_image.shape=[1,3,224,224]
batch_size=1 影响了BN层
解决办法:
①加入model.eval()
②调整batch_size
下面记录一下自己的模型简单测试学习
Method 1:
查看网络结构 网络层
model=de_model()
summary(model, input_size=(3, 224, 224))
from torchsummary import summary
class de_model(nn.Module):
def __init__(self):
super(de_model, self).__init__()
....
def forward(self,x):
pass
# 重点
model=de_model()
summary(model, input_size=(3, 224, 224))
Method 2:
直接输出 查看结果
model=de_model()
image=torch.rand((1,3,224,224))
result=model(image)
from torchsummary import summary
class de_model(nn.Module):
def __init__(self):
super(de_model, self).__init__()
....
def forward(self,x):
pass
# 重点
model=de_model()
model.eval()
image=torch.rand((1,3,224,224))
result=model(image)
Method 3:
tensor类型转为可视化的图像 这个办法十分重要!!!!!
unloader = transforms.ToPILImage()
image = x.cpu().clone() # clone the tensor
此处的x是指你想查看并想保存的图像名称!!!!!!!!
image = image.squeeze(0) # remove the fake batch dimension
image = unloader(image)
image.save('my_example.jpg')
path=r'00001.png'
image=Image.open(path)
image=image.resize((224, 224), Image.ANTIALIAS)
image = np.asarray(image) / 255.0
image = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)
print(image.shape)
#
model=dehaze_sample()
x=model(image)
print(x.shape)
#
# tensor类型------> 可视化的图像
unloader = transforms.ToPILImage()
image = x.cpu().clone() # clone the tensor
image = image.squeeze(0) # remove the fake batch dimension
image = unloader(image)
image.save('my_example.jpg')