pytorch使用中出现以上问题解决办法
以上错误是因为网络部分模块用列表list放置,没有被nn.ModuleList()处理;导致pytorch无法识别,也就没放到cuda上;
利用nn.ModuleList()可解决:
class OneModule(nn.Module):
def __init__(self):
super(OneModule, self).__init__()
self.layers = []#RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
self.layers.append(block(self.res_inplanes, planes, stride, downsample))
self.res_inplanes = planes * block.expansion
for i in range(1, blocks):
self.layers.append(block(self.res_inplanes, planes))
self.layers = nn.ModuleList(self.layers)#缺少这行代码就会报错
def forward(self, x):
pass