1.环境配置
Python、Pytorch、Torchvision、Pandas、PIL等。
2.数据集
数据来源中国交通标志检测数据集,数据集包含58个类别的5998个交通标志图像。每个图像都是单个交通标志的缩放视图。选择90%数据作为训练数据,10%作为测试数据。
3.数据展示
下图为torchvision预处理后的图片,不是原图。
4.数据加载
加载数据需要继承Dataset父类,并且__getitem__()和__len__()两个方法必须要重写,getitem()方法是获取数据及数据标签,一次返回一张图片数据,len()方法是整个数据的长度,也就是图片的数量。最后通过DataLoader类加载。
class TrafficData(Dataset):
def __init__(self, path, train=True):
super(TrafficData, self).__init__()
df = pd.read_csv(os.path.join(path, 'annotations.csv'))
labels = df.category.tolist()
image_files = df.file_name.tolist()
self.path = path
del df
self.transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((128, 128)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307), (0.3081))
])
if train:
self.image_files = image_files[:int(len(image_files)*0.9)]
self.labels = labels[:int(len(image_files)*0.9)]
else:
self.image_files = image_files[int(len(image_files)*0.9):]
self.labels = labels[int(len(image_files)*0.9):]
def __getitem__(self, index):
image = Image.open(os.path.join(self.path + '/images/', self.image_files[index]))
return self.transform(image), self.labels[index]
def __len__(self):
return len(self.image_files)
train_loader = DataLoader(dataset=TrafficData('../input/chinese-traffic-signs', train=True),
batch_size=32, shuffle=True, drop_last=True)
test_loader = DataLo