I. Cause of the Problem
TypeError: Variable data has to be a tensor, but got tuple
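The error can be reproduced in isolation. The following is a minimal sketch, assuming a PyTorch install where the legacy `torch.autograd.Variable` wrapper is still importable; the `try_wrap` helper and the label values are hypothetical and only serve the illustration.

```python
import torch
from torch.autograd import Variable


def try_wrap(data):
    """Return (True, variable) on success, (False, error message) on TypeError."""
    try:
        return True, Variable(data)
    except TypeError as exc:
        return False, str(exc)


# A tensor can be wrapped without complaint.
ok_tensor, _ = try_wrap(torch.zeros(2, 3))

# A tuple of strings (made-up labels) cannot be wrapped.
ok_tuple, message = try_wrap(("dog", "cat"))

print("tensor wrapped:", ok_tensor)
print("tuple wrapped:", ok_tuple)
print("message:", message)
```

Anything that is not already a tensor has to be converted before it reaches `Variable`.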
II. Source Code
1. Training on the dataset
from torch.autograd import Variable

# train_loader, model and criterion are defined elsewhere
for epoch in range(10):
    for i, (images, labels) in enumerate(train_loader):
        # Fails on this line: labels is a tuple of str, not a tensor
        images, labels = Variable(images), Variable(labels)
        print(images, labels)
        print(images.size())
        outputs = model(images)
        print('outputs = ', outputs)
        loss = criterion(outputs, labels)
III. Debugging
Modify the source to print the types of images and labels. The output shows that labels is a tuple.
for epoch in range(10):
    for i, (images, labels) in enumerate(train_loader):
        # Print the types before the conversion, otherwise the error hides them
        print('type(images) = ', type(images))
        print('type(labels) = ', type(labels))
        images, labels = Variable(images), Variable(labels)
        print(images, labels)
        print(images.size())
        outputs = model(images)
        print('outputs = ', outputs)
        loss = criterion(outputs, labels)
Printed output: type(labels) is reported as tuple.
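IV. Checking the Dataset Types
When the file is read, the image path and the label are both of type str.
While the training set is built, the image is converted to a Tensor, but the label stays a str.
When training starts and the loss function runs, the image goes from Tensor to Variable, while the label would have to go from tuple to Variable, which fails. The tuple most likely comes from the DataLoader's default collation, which groups the str labels of a batch together without turning them into a tensor.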
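The type change can be checked without any image files. This is a minimal sketch, assuming the default collate behaviour of `torch.utils.data.DataLoader`; `ToyDataset` and its random "images" are hypothetical stand-ins for the real dataset.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical dataset: random 'images' paired with str labels."""

    def __init__(self):
        self.samples = [
            (torch.rand(3, 4, 4), 'dog'),
            (torch.rand(3, 4, 4), 'cat'),
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]


loader = DataLoader(ToyDataset(), batch_size=2, shuffle=False)
images, labels = next(iter(loader))

# The images are stacked into one tensor; the str labels are only grouped.
print('type(images) = ', type(images))
print('type(labels) = ', type(labels))
```

The images arrive as a single stacked tensor, while the labels are still plain Python strings in a container, which `Variable` cannot wrap.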
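V. Solution: One-Hot Encode the Labels
The loss function needs both outputs and labels to be Variables, but labels holds strings, so the labels are one-hot encoded when the dataset is built. Source code: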
import torch
from PIL import Image
from torch.utils.data import Dataset

txt_path = 'G:/train/path.txt'
num_classes = 2
batch_size = 1

class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        imgs = []
        # Each line of the file is "<image path>\t<label>"
        with open(txt_path, 'r') as fh:
            for line in fh:
                line = line.strip('\n')
                line = line.rsplit('\t')
                print('type(line[1])=', type(line[1]))
                # Map the str label to a class index: dog -> 0, anything else -> 1
                if line[1] == 'dog':
                    label = [0]
                else:
                    label = [1]
                # scatter_ needs an int64 index of shape (batch_size, 1)
                label = torch.LongTensor([label])
                # Write a 1 at the class index to get the one-hot row
                label = torch.zeros(batch_size, num_classes).scatter_(1, label, 1)
                print(label)
                print('type(label) = ', type(label))
                imgs.append((line[0], label))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img_path, label = self.imgs[index]
        img = Image.open(img_path).convert('RGB')
        # Image.LANCZOS is the current name of the filter formerly called ANTIALIAS
        img = img.resize((224, 224), Image.LANCZOS)
        if self.transform is not None:
            img = self.transform(img)
        return img, label
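The encoding step on its own can be sketched as follows. This is a minimal illustration rather than the author's exact code: the `classes` list and the `one_hot_label` helper are hypothetical, and the block shows `scatter_` next to `torch.nn.functional.one_hot` to check that both give the same row.

```python
import torch
import torch.nn.functional as F

classes = ['dog', 'cat']  # hypothetical class order: dog -> 0, cat -> 1
num_classes = len(classes)


def one_hot_label(name):
    """Encode a str label as a float one-hot tensor of shape (num_classes,)."""
    # scatter_ expects an int64 index with the same number of dims as the target.
    index = torch.tensor([[classes.index(name)]])  # shape (1, 1)
    row = torch.zeros(1, num_classes).scatter_(1, index, 1.0)
    return row.squeeze(0)


dog = one_hot_label('dog')
cat = one_hot_label('cat')

# The built-in helper returns int64, so cast to float before comparing.
cat_builtin = F.one_hot(torch.tensor([1]), num_classes=num_classes)[0].float()

print(dog, cat, cat_builtin)
```

Because each label is now a tensor, the DataLoader can stack the labels of a batch into one tensor, and the conversion in the training loop no longer receives a tuple.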