import os, sys, glob, shutil, json
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
class SVHNDataset(Dataset):
    """Map-style dataset of digit images with fixed-length label tensors.

    Each item is ``(image, label)``:

    * ``image`` is the file at ``img_path[index]`` converted to RGB, then
      passed through ``transform`` when one was given.
    * ``label`` is an int64 tensor of exactly 5 entries. It holds the digits
      of ``img_label[index]``, truncated to 5 and right-padded with 10.

    Args:
        img_path: Sequence of image file paths.
        img_label: Sequence of per-image digit lists; ``img_label[i]`` is
            assumed to describe ``img_path[i]``.
        transform: Optional callable applied to the PIL RGB image.
    """

    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform

    def __getitem__(self, index):
        # The context manager closes the underlying file. convert() returns a
        # converted copy, which stays usable after the file is closed.
        with Image.open(self.img_path[index]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self._encode_label(self.img_label[index])

    def __len__(self):
        return len(self.img_path)

    @staticmethod
    def _encode_label(label, max_len=5, pad_value=10):
        """Return ``label`` as an int64 tensor of exactly ``max_len`` entries.

        Longer labels are truncated and shorter ones are right-padded with
        ``pad_value``. In the original SVHN release, class 10 stands for the
        digit 0. In this code, 10 is used as the padding ("no digit") class.

        The explicit int64 dtype replaces the ``np.int`` alias, which was
        removed in NumPy 1.24 and made every item lookup raise.
        """
        digits = [int(d) for d in label][:max_len]
        digits += [pad_value] * (max_len - len(digits))
        return torch.tensor(digits, dtype=torch.int64)
# Training image paths, sorted so their order is deterministic.
train_path = glob.glob('Dataset/mchar_train/*.png')
train_path.sort()
# Per-image annotations. The context manager closes the file handle promptly.
# JSON text is UTF-8 by specification, so don't rely on the locale default.
with open('Dataset/mchar_train.json', encoding='utf-8') as label_file:
    train_json = json.load(label_file)
# NOTE(review): labels are paired with train_path purely by position. This
# assumes the JSON keys appear in the same order as the sorted file names.
# Confirm this, or look labels up by os.path.basename(path) instead.
train_label = [train_json[x]['label'] for x in train_json]
# Augmentation pipeline used for the matplotlib preview.
# ToTensor() and Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) are
# intentionally left out. They would turn each sample into a normalized CHW
# tensor, which plt.imshow cannot render as-is. Add them back when feeding a
# model.
data = SVHNDataset(
    train_path,
    train_label,
    transform=transforms.Compose([
        transforms.Resize((64, 128)),  # fixed (height, width)
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomRotation(degrees=5),  # random angle in [-5, 5] degrees
    ]),
)
img_num = len(data)
print(img_num)
# Only show the first 9 images, laid out on a 3x3 grid.
img_num = min(img_num, 9)
# Iterate over 0-based dataset indices. The old 1..img_num range skipped
# sample 0 and indexed past the end whenever the dataset had <= 9 items.
for i in range(img_num):
    plt.subplot(3, 3, i + 1)  # subplot positions are 1-based
    img, lbl = data[i]
    # Values >= 10 are padding, not digits; keep only 0-9 for the title.
    digits = lbl.numpy()
    plt.title(''.join(str(d) for d in digits[digits < 10]))
    plt.imshow(img)
    plt.axis('off')
# Without this, a plain script exits without ever displaying the figure.
plt.show()