东阳的学习记录,坚持就是胜利!
文章目录
Dataloader和transform的关系
dataloader是一个多进程迭代器
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False)[source]
Data loader. Combines a dataset and a sampler, and providessingle- or multi-process iterators over the dataset.
建议自己debug一遍,在这里不详细写了:
- transform:在__getitem__()中调用
def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
- collate_fn:collate_fn的作用是将生成的data列表转化成(B, C, H, W)的四维张量模式。
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
transform
例子:
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
# 1 Pad
transforms.Pad(padding=32, fill=(255, 0, 0), padding_mode='constant'),
transforms.Pad(padding=(8, 64), fill=(255, 0, 0), padding_mode='constant'),
transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='constant'),
transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='symmetric'),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)
])
# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
transform的常用方法
- 数据中心化
- 数据标准化
- 缩放
- 裁剪
- 旋转
- 翻转
- 填充
- 噪声添加
- 灰度变换
- 线性变换
- 仿射变换(对形状进行变换)
- 亮度、饱和度及对比度变换
怎么选择适合的数据增强/处理方法
原则: 让 训练集 与 测试集 更接近
(而不是瞎增强)
- 空间位置:平移、旋转
- 色彩:灰度图,色 彩 抖动
- 形状:仿射变换
- 上下文场景:遮挡 , 填充
Dataset类
- getitem():接收一个索引,返 回 一个 样 本
- collate_fn():collate_fn的作用是将生成的data列表转化成(B, C, H, W)的四维张量模式。
Dataloader类
- dataset: D ataset 类 , 决定数据从 哪 读取
及如何读取 - batchsize : 批大小
- num _works: 是否多进程读取数 据
- shuffle: 每个epoch是否乱序
- drop_last:当 样 本数 不 能被 batchsize整除时,是否舍弃最后一批数据
在Dataset中加入数据增强的方法
例子:num_rep
:表示数据集扩张倍数
def __len__(self):
return len(self.annotations) * self.num_rep
def __getitem__(self, item) -> (Study, Any, (torch.Tensor, torch.Tensor)):
item = item % len(self.annotations)
key, (v_annotation, d_annotation) = self.annotations[item]
return self.studies[key[0]], key, v_annotation, d_annotation
不要在getitem中写太多的逻辑
可以使用一个静态方法,避免多次重复调用
@staticmethod
def get_img_info(data_dir):
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = rmb_label[sub_dir]
data_info.append((path_img, int(label)))
return data_info