街景字符编码识别 - Task2 - 数据读取与数据扩增
图像读取
OpenCV读取图片
import cv2
# 导入Opencv库
img = cv2.imread('cat.jpg')
# Opencv默认颜色通道顺序是BRG,转换一下
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
使用cv.imread()函数读取图像。图像应该在工作目录或图像的完整路径应给出。
第二个参数是一个标志,它指定了读取图像的方式。
- cv.IMREAD_COLOR: 加载彩色图像。任何图像的透明度都会被忽视。它是默认标志。
- cv.IMREAD_GRAYSCALE:以灰度模式加载图像
- cv.IMREAD_UNCHANGED:加载图像,包括alpha通道
注意 除了这三个标志,你可以分别简单地传递整数1、0或-1。
OpenCV显示图片
使用函数cv.imshow()在窗口中显示图像。窗口自动适合图像尺寸。
第一个参数是窗口名称,它是一个字符串。第二个参数是我们的对象。你可以根据需要创建任意多个窗口,但可以使用不同的窗口名称。
cv.imshow('image',img)
cv.waitKey(0)
cv.destroyAllWindows()
数据扩增
(找到了阿水哥发的文章)
在深度学习中数据扩增方法非常重要,数据扩增可以增加训练集的样本,同时也可以有效缓解模型过拟合的情况,也可以给模型带来的更强的泛化能力。
数据扩增方法
在常见的数据扩增方法中,一般会从图像颜色、尺寸、形态、空间和像素等角度进行变换。当然不同的数据扩增方法可以自由进行组合,得到更加丰富的数据扩增方法。
以torchvision为例,常见的数据扩增方法包括:
- transforms.CenterCrop 对图片中心进行裁剪
- transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
- transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
- transforms.Grayscale 对图像进行灰度变换
- transforms.Pad 使用固定值进行像素填充
- transforms.RandomAffine 随机仿射变换
- transforms.RandomCrop 随机区域裁剪
- transforms.RandomHorizontalFlip 随机水平翻转
- transforms.RandomRotation 随机旋转
- transforms.RandomVerticalFlip 随机垂直翻转
在本次赛题中,赛题任务是需要对图像中的字符进行识别,因此对于字符图片并不能进行翻转操作。比如字符6经过水平翻转就变成了字符9,会改变字符原本的含义。
常用的数据扩增库
-
torchvision
https://github.com/pytorch/vision
pytorch官方提供的数据扩增库,提供了基本的数据数据扩增方法,可以无缝与torch进行集成;但数据扩增方法种类较少,且速度中等; -
imgaug
https://github.com/aleju/imgaug
imgaug是常用的第三方数据扩增库,提供了多样的数据扩增方法,且组合起来非常方便,速度较快; -
albumentations
https://albumentations.readthedocs.io
是常用的第三方数据扩增库,提供了多样的数据扩增方法,对图像分类、语义分割、物体检测和关键点检测都支持,速度较快。
pytorch读取数据
重载Dataset
# 重载torch.utils.data.dataset
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform = None):
self.img_path = img_path # 图片路径
self.img_label = img_label # 标签路径
if transform is not None: # 数据变换、增广
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB') # BGR -> RGB
if self.transform is not None:
img = self.transform(img)
# 原始SVHN中类别10为数字0
lbl = np.array(self.img_label[index], dtype = np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10] # 扩增位5位 填充10为空白无数字
return img, torch.from_numpy(np.array(lbl[:]))
def __len__(self):
return len(self.img_path) # 数据len
定义Dataset
# 定义好训练数据和验证数据的Dataset
path = './The Street View House Numbers Dataset/'
train_path = glob.glob(path + 'mchar_train/*.png')
train_path.sort() # 排序 00000.png 0001.png ...
train_json = json.load(open(path + 'mchar_train.json'))
train_label = [train_json[x]['label'] for x in train_json]
# 定义Dataset
train_dataset = SVHNDataset(train_path, train_label,
transform = transforms.Compose([
# 缩放到固定尺寸
transforms.Resize((64, 128)),
# 随机颜色变换
transforms.ColorJitter(0.2, 0.2, 0.2),
# 加入随机旋转
transforms.RandomRotation(5),
# 将图片转换为pytorch 的tesntor
transforms.ToTensor(),
# 对图像像素进行归一化
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]))
# 定义DataLoader 返回批量数据iteration
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size = 40, # 批量大小40
shuffle = True, # 随机读取
)
读取一次数据看一下
for img, label in train_loader:
# 这里删了数据增广 将图片变回Normalize前
img = np.array(img[0].permute(1,2,0)) * np.array([0.229,0.224,0.225]).reshape((1, 1, 3)) + np.array([0.485,0.456,0.406]).reshape((1, 1, 3))
print(img.shape)
plt.figure(figsize = (10, 10))
plt.plot(111)
plt.imshow(img)
plt.xticks([]); plt.yticks([])
print(label[0])
break