Torch数据加载维度不一致报错

1、ERROR

RuntimeError: stack expects each tensor to be equal size, but got [238, 128] at entry 0 and [55, 128] at entry 128。

2、解决方法

  1. 方案一: 数据预处理时,最好排除特征数据维度不一致的样本;
  2. 方案二: 如果实在无法排除,又在数据加载时报错,可以尝试:
    1) 首先在get_item函数构建中判断数据维度,不符合目标维度的数据返回None类型。
         if X.shape == (238,128):
             return X,y
         else:
             return None,None
  1. 然后在数据加载的时候过滤掉None的样本
    注意: 需要import 默认的拼接方式
from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
def my_collate_fn(batch):
    '''
    batch中每个元素形如(data, label)
    '''
    # 过滤为None的数据
    batch = list(filter(lambda x:x[0] is not None, batch))
    if len(batch) == 0: return torch.Tensor()
    return default_collate(batch) # 用默认方式拼接过滤后的batch数据        

3)数据加载

data_loader = DataLoader(dataset=dataset,
                         batch_size=batch_size,
                         collate_fn=my_collate_fn,
                         sampler=sampler)

以上三步,即可顾虑掉维度不符合要求的样本。

3、参考文献

“pytorch 函数DataLoader“: https://blog.csdn.net/TH_NUM/article/details/80877687

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
torch.Size([1, 5216544, 12, 1])变成torch.Size([16992,307, 12, 1])的方法和上面回答的一样,使用torch.reshape()函数即可。 ```python import torch # 将原始数据reshape成目标维度 data = torch.randn(1, 5216544, 12, 1) data_reshape = data.reshape(16992, 307, 12, 1) print(data_reshape.shape) ``` 接下来,我们要将新生成的data_reshape和原始数据集进行拼接,生成(16992,307,12,3)的新数据集。代码如下: ```python # 加载数据集,假设数据已经reshape成了目标维度 dataset = torch.randn(16992, 307, 12, 2) # 将数据集和新生成的data_reshape进行拼接 new_dataset = torch.cat((data_reshape.repeat(1, 307, 1, 1), dataset), dim=3) print(new_dataset.shape) ``` 在使用torch.cat()函数进行拼接时,报错信息"Sizes of tensors must match except in dimension 3. Expected size 16992 but got size 1 for tensor number 1 in the list."的意思是,两个待拼接的tensor在第四个维度上的大小不匹配,期望的大小是16992,但是实际上只有1。这是因为原始数据集的第四个维度是2,而新生成的data_reshape的第四个维度是1。我们可以通过使用torch.repeat()函数将data_reshape在第一个维度上复制16992次,使得它的第一个维度的大小与原始数据集相同,从而解决这个问题。 在上述代码中,我们首先使用data_reshape.repeat(1, 307, 1, 1)将data_reshape在第一个维度上复制307次,从而得到一个大小为[16992, 307, 12, 1]的tensor。然后,我们使用torch.cat()函数将data_reshape和原始数据集按照第四个维度进行拼接,生成了torch.Size([16992, 307, 12, 3])的新数据集。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值