NLP中处理数据用到了DataLoader,在定义Collate_fn的时候,一直出现ValueError,分别是:
ValueError:not enough values to unpack(expected 2, got 1)
VauleError:too much values to unpack (expected 2)
出错的对应行是:
def collate_fn(batch):
for x,y in batch //出错行
... ...
... ...
return ... ...
(关于collate_fn的input和output这里简要介绍,XJTU-Qidong的: pytorch中collate_fn函数的使用&如何向collate_fn函数传参介绍的非常详细。)
出错的是batch行,那么就从这里入手。Collate_fn的input是batch,而batch来源于class dataset的__getitem__函数,class dataset结构如下:
class dataset(dataset):
def __init__(self,path):
self.x = []
self.y = []
def __len__(self):
... ...
def__getitem__(self,idx):
... ...
return self.x,self.y
x,y即为batch内容,batch的shape是(batch_size,2)。问题便出现在这里, 返回值x和y是list形式,而collate_fn需要的输入是dict形式,所以需要把__getitem__段return部分改为:
return {'x':self.x[idx], 'y':self.y[idx]}
可以用以下的代码段检测输出
dataset = dataset() //实例化class dataset()
dataset.__init__(path)
dataset[idx]['x']
dataset[idx]['y']
输出应为某一对被提取的x和y值。