torch.utils.data下的TensorDataset和DataLoader的使用
一、TensorDataset
对给定的tensor数据(样本和标签),将它们包装成dataset
'''
data_tensor (Tensor) - 样本数据
target_tensor (Tensor) - 样本目标(标签)
'''
dataset=torch.utils.data.TensorDataset(data_tensor, target_tensor)
二、DataLoader
数据加载器,组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。它可以对我们上面所说的数据集Dataset作进一步的设置。
'''
dataset (Dataset) – 加载数据的数据集。
batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则shuffle必须设置成False。
num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
pin_memory:内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。
drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。
如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
timeout:是用来设置数据读取的超时时间的,如果超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。
'''
data_iter=torch.utils.data.DataLoader(dataset, batch_size=1,
shuffle=False, sampler=None,
batch_sampler=None, num_workers=0,
collate_fn=None, pin_memory=False,
drop_last=False, timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
三、pytorch使用torch.nn.Sequential快速搭建神经网络
pytorch使用torch.nn.Sequential快速搭建神经网络
torch.nn.Sequential是一个Sequential容器,模块将按照构造函数中传递的顺序添加到模块中。另外,也可以传入一个有序模块。 为了更容易理解,官方给出了一些案例:
Sequential使用实例
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
Sequential with OrderedDict使用实例
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
四、关于torch.flatten
先看函数参数:
torch.flatten(input, start_dim=0, end_dim=-1)
input: 一个 tensor,即要被“推平”的 tensor。
start_dim: “推平”的起始维度。
end_dim: “推平”的结束维度。
首先如果按照 start_dim 和 end_dim 的默认值,那么这个函数会把 input 推平成一个 shape 为 [n] 的tensor,其中 n 即 input 中元素个数。
五、Linear()
六、 激活函数