1 Imagenet验证集介绍
Imagenet验证集数据大小为6.5G,共有1000类的50000张图片。本文主要是对这1000类的50000张图片的标签信息进行处理分类汇总成一个csv表格,便于实验读入信息需要。Imagenet验证集标签整理的文件和代码链接如下所示:https://download.csdn.net/download/qq_38406029/86030944
2 Imagenet验证集处理
待处理的文件有两个,一个是imagenet_img_info.txt文件,它包含了50000张图片与1000个类别的对应关系。
另一个文件是imagenet_label_info.txt文件,它包含了Imagenet数据集中1000个类别详细信息。
最终输出的用于实验的csv文件如下所示,它详细的概括了图片与其类别的关系。
3 标签分类信息代码
用imagenet_label_info.txt文件和imagenet_img_info.txt文件最终生成的csv文件的程序如下所示:
import os
import re
import json
import csv
# Build selected_imagenet.csv from the two ImageNet metadata files:
#   path1 - JSON mapping class index (string) -> [class_id, class_name]
#   path2 - one "<image_name>\t<class_index>" record per line
path1 = 'imagenet_label_info.txt'
path2 = 'imagenet_img_info.txt'

# Load the class metadata once; 'with' guarantees the handle is closed.
with open(path1, 'r') as f1:
    json_all = json.loads(f1.read())

# Parse the image->class mapping, preserving the file's image order.
img_dict = {}   # image name -> class index (as a string)
img_list = []   # image names in original file order
with open(path2, 'r') as f2:
    for line in f2:
        fields = line.strip().split('\t')
        img_dict[fields[0]] = fields[1]
        img_list.append(fields[0])

# One (class_index, class_id, image_name, class_name) row per image.
data_list = []
for img_name in img_list:
    class_idx = img_dict[img_name]  # look up once instead of four times
    data_list.append((
        int(class_idx),
        json_all[class_idx][0],
        img_name.strip(),
        json_all[class_idx][1],
    ))

header = ['class_index', 'class', 'image_name', 'class_name']
with open('selected_imagenet.csv', 'w+', newline='') as file_obj:
    writer = csv.writer(file_obj)
    writer.writerow(header)
    writer.writerows(data_list)

# Read the file back and echo it as a quick sanity check.
with open('selected_imagenet.csv', 'r') as check_file:
    for row in csv.reader(check_file):
        print(row)
4 验证集归类文件夹
import csv
import os
import shutil
# Copy each validation image into img_deal/<class_id>/ according to the
# mapping recorded in selected_imagenet.csv.
path_val = 'ILSVRC2012_img_val'      # directory with the raw validation images
path_csv = 'selected_imagenet.csv'   # rows: class_index, class_id, image_name, class_name
path_root = 'img_deal'               # destination root, one sub-folder per class

with open(path_csv, 'r') as csv_file:
    reader = csv.reader(csv_file)
    next(reader, None)  # skip the header row once, up front
    for item in reader:
        item_path = os.path.join(path_root, item[1])
        # Create the class folder on demand; exist_ok avoids a pre-check.
        os.makedirs(item_path, exist_ok=True)
        shutil.copy(os.path.join(path_val, item[2]), item_path)
5 Imagenet预训练模型分类
import os
import torch
import torchvision.transforms as T
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset
import csv
import numpy as np
import pretrainedmodels
import PIL.Image as Image
import ssl
# Disable HTTPS certificate verification so torchvision / pretrainedmodels
# can download pretrained weights on machines with broken cert chains.
ssl._create_default_https_context = ssl._create_unverified_context
# Cache downloaded model weights under this local directory instead of ~/.cache.
os.environ['TORCH_HOME']=r'F:\code\Imagenet\model'
class SelectedImagenet(Dataset):
    """Dataset over the first 1000 rows of the selected-images CSV.

    Each CSV row is (class_index, class_id, image_name, class_name);
    the image file itself is loaded from ``imagenet_val_dir`` on access.

    Args:
        imagenet_val_dir: directory containing the validation JPEGs.
        selected_images_csv: CSV produced by the label-aggregation script.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, imagenet_val_dir, selected_images_csv, transform=None):
        super(SelectedImagenet, self).__init__()
        self.imagenet_val_dir = imagenet_val_dir
        self.selected_images_csv = selected_images_csv
        self.transform = transform
        self._load_csv()

    def _load_csv(self):
        # Read the whole CSV once up front; 'with' closes the handle
        # (the original leaked it). Keep at most the first 1000 rows.
        with open(self.selected_images_csv, 'r') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            self.selected_list = list(reader)[0:1000]

    def __getitem__(self, item):
        target, target_name, image_name, _ = self.selected_list[item]
        image = Image.open(os.path.join(self.imagenet_val_dir, image_name))
        # Some ImageNet validation images are grayscale/CMYK; force RGB.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, int(target)

    def __len__(self):
        return len(self.selected_list)
# Evaluate a pretrained classifier on the selected validation images and
# count top-1 correct predictions.
model_name = 'senet'
if model_name == 'resnet':
    # The weights enum must be fully qualified; a bare ResNet50_Weights
    # was a NameError in the original.
    model = torchvision.models.resnet50(
        weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
elif model_name == 'senet':
    model = pretrainedmodels.__dict__['senet154'](num_classes=1000, pretrained='imagenet')
else:
    # Fail loudly here instead of printing and crashing later on an
    # undefined `model`.
    raise ValueError('No implementation for model: {}'.format(model_name))

batch_size = 4
model.eval()
device = 'cpu'
model.to(device)

# Inception-style models expect 299x299 inputs; everything else 224x224.
if model_name in ['inception']:
    input_size = [3, 299, 299]
else:
    input_size = [3, 224, 224]
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
norm = T.Normalize(tuple(mean), tuple(std))
resize = T.Resize(tuple(input_size[1:]))
# NOTE(review): the extra Resize after the 224 CenterCrop is a no-op for
# 224-input models but would upscale the crop for 299-input ones — confirm
# this is the intended preprocessing before adding inception models.
trans = T.Compose([
    T.Resize((256, 256)),
    T.CenterCrop((224, 224)),
    resize,
    T.ToTensor(),
    norm,
])

dataset = SelectedImagenet(
    imagenet_val_dir='data/imagenet/ILSVRC2012_img_val',
    selected_images_csv='data/imagenet/selected_imagenet.csv',
    transform=trans,
)
ori_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=False,
    num_workers=0, pin_memory=False)

correct = 0
# Inference only: no_grad skips autograd bookkeeping without changing outputs.
with torch.no_grad():
    for ind, (ori_img, label) in enumerate(ori_loader):
        ori_img = ori_img.to(device)
        label = label.to(device)
        predict = model(ori_img)
        predicted = torch.max(predict.data, 1)[1]  # top-1 class indices
        correct += (predicted == label).sum()
# Total number of correctly classified images.
print(correct)