引言
在基于Pytorch的深度学习的代码中基本都是用Dataset和DataLoader来加载数据,本文起源于一个疑惑:看到某Dataset类中定义的collate_fn中有这么一句:
def collate_fn(self, batch)
ims, classes = list(zip(*batch)))
我一想,本来从DataLoader加载出来的batch不就是数据跟标签是分开的吗,这里怎么又要给他分开?于是研究了一波:
先放上用于演示的代码:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
x = [[1,2],[3,4],[5,6],[7,8]]
y = [[3],[7],[11],[15]]
X = torch.tensor(x).float()
Y = torch.tensor(y).float()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
X = X.to(device)
Y = Y.to(device)
class MyDataset(Dataset):
def __init__(self,x,y):
self.x = torch.tensor(x).float()
self.y = torch.tensor(y).float()
def __len__(self):
return len(self.x)
def __getitem__(self, ix):
return self.x[ix], self.y[ix]
ds = MyDataset(X, Y)
dl = DataLoader(ds, batch_size=2, shuffle=True)
class MyNeuralNet(nn.Module):
def __init__(self):
super().__init__()
self.input_to_hidden_layer = nn.Linear(2,8)
self.hidden_layer_activation = nn.ReLU()
self.hidden_to_output_layer = nn.Linear(8,1)
def forward(self, x):
x = self.input_to_hidden_layer(x)
x = self.hidden_layer_activation(x)
x = self.hidden_to_output_layer(x)
return x
mynet = MyNeuralNet().to(device)
loss_func = nn.MSELoss()
from torch.optim import SGD
opt = SGD(mynet.parameters(), lr = 0.001)
import time
loss_history = []
start = time.time()
for _ in range(50):
for data in dl:
x, y = data
opt.zero_grad()
loss_value = loss_func(mynet(x),y)
loss_value.backward()
opt.step()
loss_history.append(loss_value)
end = time.time()
print(end - start)
val_x = [[10,11]]
val_x = torch.tensor(val_x).float().to(device)
mynet(val_x)
代码是jupyter写的,用pycharm自己改改也能运行,要完成的事情就是训练一个模型计算两数之和
现在没用collate_fn这里print一下data看得出来,数据和标签是分开的,长度都是一个batch(2)的大小
现在我把collate_fn写进dataset里,直接打印一下batch看看效果(这里需要把batch return了,因为dataloader如果选择了用collate_fn参数的话最终的输出是用collate_fn指定的函数的返回值给出的(这里设置函数名同参数名,这是默认的,可以自定义函数名))
别忘改参数:
结果如下:
可以发现不止collate_fn函数内部的batch值是一个个样本对(原来是图片和真值分开),dataloader中返回的data也是一个个样本对,这就涉及到collate_fn的工作原理了:
collate_fn的作用:
-
collate_fn
: 这是一个函数或者None
。当数据集中的样本具有不同的大小或类型时,collate_fn
函数用于将样本组合成一个批次。通常,它会处理样本的填充、对齐或转换等操作,以便批次中的所有样本具有相同的形状或类型。如果样本具有相同的大小或类型,可以将collate_fn
设置为None
。 -
collate_fn所指向的函数本来就是对一个批次的数据进行操作,在数据增强,数据预处理等方面也有着广泛应用
dataloader的工作流程:
没collate_fn的情况下,
随机/不随机(shuffle)选择batch个索引传入dataset里的
__getitem__(self, idx)得到对应的数据(样本对),经过处理分别传出 数据 和 真值(样本对被自动拆开重组了)
有collate_fn的情况下,
随机/不随机(shuffle)选择batch个索引传入dataset里的
__getitem__(self, idx)得到对应的数据,将这些数据(样本对)传入collate_fn指定函数进行处理,因为没有自动的将样本对拆开重组,所以一般需要手工操作,正如文章开头的那段代码
def collate_fn(self, batch)
ims, classes = list(zip(*batch)))
一样,手工拆开重组之后对数据进行定义的操作(数据增强,其他预处理等等)再将数据和标签分别输出(这样是为了模仿不用collate_fn时将样本对自动拆开重组)
总结:
collate_fn函数会帮助批量化地自定义地处理数据,是个很好的很灵活的功能
这个视频会帮助理解