【pytorch】一个DataSet问题的解决

最近使用Hugging Face的BERT模型,遇到了数据加载上的一个问题。

“Unable to create tensor, you should probably activate truncation and/or padding with ‘padding=True’ ‘truncation=True’ to have batched tensors with the same length.”

是在DataLoader中声明的DataSet参数

    train_data_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)

这是DataSet的定义


class MyDataSet(Dataset):
 def __init__(self, path):
     self.text_list = []
     self.label_list = []
     self.tokenizer = BertTokenizer.from_pretrained(path)
     with open(path, 'r',encoding='utf-8') as f:
         text, label = f.readline().split('\t')
         label=torch.tensor([ int(label.strip())])
         tokenizer =self.tokenizer(text, padding='max_length',return_token_type_ids=True, return_attention_mask=True,max_length=128, truncation=True, return_tensors='pt')
         self.text_list.append(tokenizer)
         self.label_list.append(label)

 def __getitem__(self, item):
     return self.text_list[item], self.label_list[item]

 def __len__(self):
     return len(self.text_list)

很诧异,到底哪里错了。网上的资料也都语焉不详。
直到我发现把BERT的路径搞错了。

        self.tokenizer = BertTokenizer.from_pretrained(path)

改为

        self.tokenizer = BertTokenizer.from_pretrained(bert_path)
        #这里的bert_path是我电脑上定义的路径,path是要加载的数据的路径。

希望这个报错信息能带给你一些启发。

Pytorch中遍历dataset可以使用torch.utils.data.DataLoader这个类。在初始化DataLoader时,一般常用的参数有dataset、batch_size、shuffle和num_workers等。其中dataset就是我们构建的自定义dataset类。在使用时,可以直接使用for循环来遍历dataloader对象,并且可以通过迭代器的方式输出每个batch的数据。具体实现如下: ```python import torch from torch.utils.data import DataLoader # 创建自定义的dataset对象 dataset = MyDataset() # 创建dataloader对象,并指定batch_size和是否进行数据打乱 dataloader = DataLoader(dataset, batch_size=32, shuffle=True) # 遍历dataloader对象 for batch_data in dataloader: # 处理每个batch的数据 inputs, labels = batch_data # 进行模型的训练或预测等操作 ... ``` 在遍历dataloader时,实际上是从dataset中取出数据,只是在取数据的规则上进行了一些修改,比如可以进行数据的打乱操作。因此,在遍历dataloader时,会调用自己定义的dataset类中的__getitem__()方法来获取数据。通过这种方式,我们可以方便地对数据进行mini-batch的训练。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [使用Pytorch中的Dataset类构建数据集的方法及其底层逻辑](https://blog.csdn.net/rowevine/article/details/123631144)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* [对pytorch中的dataset和dataloader的一些理解](https://blog.csdn.net/weixin_45700881/article/details/128351086)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值