出现这个Runtime error 很有可能是因为一个 batch 内每条数据长度不一致,检查是否需要 pad 或者 truncate 是否有问题。
我在使用 transformers 的 Berttokenizer 处理句子对是遇到这个问题,不同的数据有不同的特点,根据具体情况调整 truncate 策略。
encoded_pair = self.tokenizer(sent1, sent2,
padding='max_length', # Pad to max_length
truncation=TRUE, # Truncate to max_length TRUE
max_length=self.maxlen,
return_tensors='pt') # Return torch.Tensor objects
把truncation方式改为’longest_first’,问题解决。
encoded_pair = self.tokenizer(sent1, sent2,
padding='max_length', # Pad to max_length
truncation='longest_first',
max_length=self.maxlen,
return_tensors='pt') # Return torch.Tensor objects