Imagenet预训练模型验证集分类

1 Imagenet验证集介绍

Imagenet验证集数据大小为6.5G,共有1000类的50000张图片。本文主要是对这1000类的50000张图片的标签信息进行处理分类汇总成一个csv表格,便于实验读入信息需要。Imagenet验证集标签整理的文件和代码链接如下所示:https://download.csdn.net/download/qq_38406029/86030944

2 Imagenet验证集处理

待处理的文件有两个,一个是imagenet_img_info.txt文件,它包含了50000张图片与100个类别的对应关系。

另一个文件是imagenet_img_info.txt文件,它包含了Imagenet数据集中1000个类别详细信息。

最终输出的用于实验的csv文件如下所示,它详细的概括了图片与其类别的关系。

3 标签分类信息代码

用imagenet_img_info.txt文件和imagenet_img_info.txt文件最终生成的csv文件的程序如下所示:

import os
import re
import json
import csv
path1 = 'imagenet_label_info.txt'
path2 = 'imagenet_img_info.txt'

f1 = open(path1,'r')
str_all = f1.read()
json_all = json.loads(str_all)
f1.close()

img_dict = {}
img_list = []
f2 = open(path2,'r')
image_list = []
for item in f2.readlines():
	ditem = item.strip().split('\t')
	img_dict[ditem[0]] = ditem[1] 
	img_list.append(ditem[0])
f2.close()

data_list = []

for img_name in img_list:
	temp_list = []
	temp_list.append(int(img_dict[img_name]))
	temp_list.append(json_all[img_dict[img_name]][0])
	temp_list.append(img_name.strip())
	temp_list.append(json_all[img_dict[img_name]][1])
	data_list.append(tuple(temp_list))



header = ['class_index', 'class', 'image_name', 'class_name']
with open('selected_imagenet.csv', 'w+', newline='') as file_obj:
		writer = csv.writer(file_obj)
		writer.writerow(header)
		for data in data_list:
			writer.writerow(data)
reader = csv.reader(open('selected_imagenet.csv', 'r'))
for row in reader:
	print(row)


4 验证集归类文件夹

import csv
import os
import shutil

path_val = 'ILSVRC2012_img_val'
img_list = os.listdir(path_val)

path_csv = 'selected_imagenet.csv'
csvFile = open(path_csv, "r")
reader = csv.reader(csvFile)
path_root = 'img_deal'

for item in reader:
        item_path = os.path.join(path_root,item[1])
        if reader.line_num == 1:
                continue
        if not os.path.exists(item_path):
                os.makedirs(item_path)
        path_img = os.path.join(path_val,item[2])
        shutil.copy(path_img, item_path)

5 Imagenet预训练模型分类

import os
import torch
import torchvision.transforms as T
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset
import csv
import numpy as np
import pretrainedmodels
import PIL.Image as Image
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
os.environ['TORCH_HOME']=r'F:\code\Imagenet\model'

class SelectedImagenet(Dataset):
	def __init__(self, imagenet_val_dir, selected_images_csv, transform=None):
		super(SelectedImagenet, self).__init__()
		self.imagenet_val_dir = imagenet_val_dir
		self.selected_images_csv = selected_images_csv
		self.transform = transform
		self._load_csv()
	def _load_csv(self):
		reader = csv.reader(open(self.selected_images_csv, 'r'))
		next(reader)
		self.selected_list = list(reader)[0:1000]
	def __getitem__(self, item):
		target, target_name, image_name, _ = self.selected_list[item]
		image = Image.open(os.path.join(self.imagenet_val_dir, image_name))
		if image.mode != 'RGB':
			image = image.convert('RGB')
		if self.transform is not None:
			image = self.transform(image)
		return image, int(target)
	def __len__(self):
		return len(self.selected_list)


model_name = 'senet'

if model_name == 'resnet':
	model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
elif model_name == 'senet':
	model =  pretrainedmodels.__dict__['senet154'](num_classes=1000, pretrained='imagenet')
else:
	print('No implemation')


batch_size = 4
model.eval()
device = 'cpu'
model.to(device)

if model_name in ['inception']:
	input_size = [3, 299, 299]
else:
	input_size = [3, 224, 224]

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
norm = T.Normalize(tuple(mean), tuple(std))
resize = T.Resize(tuple((input_size[1:])))

trans = T.Compose([
		T.Resize((256,256)),
		T.CenterCrop((224,224)),
		resize,
		T.ToTensor(),
		norm
		])

dataset = SelectedImagenet(imagenet_val_dir='data/imagenet/ILSVRC2012_img_val',
							selected_images_csv='data/imagenet/selected_imagenet.csv',
							transform=trans
							)

ori_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers = 0, pin_memory = False)

correct = 0   

for ind, (ori_img , label) in enumerate(ori_loader):
	ori_img = ori_img.to(device)
	label = label.to(device)
	predict = model(ori_img)
	predicted = torch.max(predict.data, 1)[1]
	correct += (predicted == label).sum()
	print(correct)



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

道2024

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值