一、数据读取与数据扩增
- 图像读取
PIL和OpenCV
1.1 PIL
from PIL import Image
img=Image.open('cat.jpg')
1.2 OpenCV
import cv2
img=cv2.imread('cat.jpg')
2.数据扩增
数据扩增是本次比赛的关键,在简单扩增的情况下,训练非常容易过拟合。尝试增加更多的合适的扩增方法
3.Pytorch 读取数据
Dataset是对数据集的封装,提供索引读取数据的方式
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
DataLoader是对Dataset的封装,提供迭代读取方式
train_loader = torch.utils.data.DataLoader(
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.RandomCrop((60, 120)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])),
batch_size=40,
shuffle=True,
num_workers=1,
)
二、小结
pytorch提供的Dataset类,DataLoader类提供了方便的数据集操作,位于torch.utils.data下面。
num_workers在windows系统需要修改为0