This post describes how to do image classification with a pretrained model in PyTorch, mainly following the official Transfer Learning Tutorial.
The full code is available in Image_classification.
- Download the code: git clone https://github.com/chenmozxh/pytorch_studying
- Download the dataset: wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
  This dataset is a small subset of ImageNet containing two classes, ants and bees.
- Unzip the dataset: unzip hymenoptera_data.zip
  Dataset layout: the hymenoptera_data folder contains a training directory train and a validation directory val; each of these contains an ants folder and a bees folder holding the images of that class.
- Run python3 example1.py to start training. As the epochs progress, the loss keeps decreasing while the accuracy acc keeps rising.
- example1.py code walkthrough:
Data loading uses the built-in torchvision.datasets.ImageFolder interface. You only need to point it at the image folders data_dir/train and data_dir/val. Each of these two directories contains N subfolders, where N is the number of classes, and each subfolder holds the images of that class. torchvision.datasets.ImageFolder then behaves like a list in which every element is a tuple containing an image and its label:

```python
import os
import torch
from torchvision import datasets, transforms

def Data_loader(Data_Path):
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    data_dir = Data_Path
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=4,
                                                  shuffle=True,
                                                  num_workers=4)
                   for x in ['train', 'val']}
    class_names = image_datasets['train'].classes
    return dataloaders, image_datasets, class_names

dataloaders, image_datasets, class_names = Data_loader('hymenoptera_data')
print(image_datasets)
for e in image_datasets:
    print(e)
    print(image_datasets[e])
    for index, k in enumerate(image_datasets[e]):
        print(type(k), len(k))          # each sample is a (image, label) tuple
        print(index, k[0].size(), k[1])
```
The transforms preprocess the images. torchvision.transforms.Compose manages the whole sequence of transform operations. RandomResizedCrop and RandomHorizontalFlip take a PIL Image as input, i.e. the image is read with Python's PIL library. Normalize, however, operates on a Tensor, so a ToTensor() step is inserted to convert the PIL Image into a Tensor first. Note that transforms.Scale(256) was the old name for the resize operation; it has since been replaced by transforms.Resize.
ImageFolder only yields individual (image, label) samples, list-style, and such a sequence cannot be fed to the model directly. PyTorch therefore wraps it in another class, torch.utils.data.DataLoader, which collates the images and the labels each into a batched Tensor that the model can consume.
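The batching behavior can be seen without any image files at all. This sketch uses random tensors as a toy stand-in for ImageFolder's (image, label) samples (the shapes are assumptions matching the 224x224 crops above):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for ImageFolder output: 10 fake "images" with binary labels.
images = torch.randn(10, 3, 224, 224)
labels = torch.randint(0, 2, (10,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_imgs, batch_labels = next(iter(loader))
print(batch_imgs.shape)    # torch.Size([4, 3, 224, 224])
print(batch_labels.shape)  # torch.Size([4])
```

With batch_size=4, the loader stacks four samples into one Tensor per field, which is the format the model's forward pass expects.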
Another very important class is torch.utils.data.Dataset. It is an abstract class, and every dataset class in PyTorch inherits from it; torchvision.datasets.ImageFolder is one example. (torch.utils.data.DataLoader does not inherit from Dataset; it consumes one.) So if your data is not stored in the layout described above, you need to write your own class to read it, and that class must inherit from the torch.utils.data.Dataset base class. The code is as follows:

```python
import os
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

def default_loader(path):
    try:
        img = Image.open(path)
        return img.convert('RGB')
    except IOError:
        print("Cannot read image: {}".format(path))

class customData(Dataset):
    def __init__(self, img_path, txt_path, dataset='',
                 data_transforms=None, loader=default_loader):
        # Each line of the txt file: "<relative image path> <single-digit label>"
        with open(txt_path) as input_file:
            lines = input_file.readlines()
            self.img_name = [os.path.join(img_path, line.strip()[:-2]) for line in lines]
            self.img_label = [int(line.strip()[-1:]) for line in lines]
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)
        if self.data_transforms is not None:
            try:
                img = self.data_transforms[self.dataset](img)
            except Exception:
                print("Cannot transform image: {}".format(img_name))
        return img, label

def Data_loader():
    batch_size = 4
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    image_datasets = {x: customData(img_path='hymenoptera_data_cp/',
                                    txt_path=(x + '.txt'),
                                    data_transforms=data_transforms,
                                    dataset=x)
                      for x in ['train', 'val']}
    # wrap data and labels into batched Tensors
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=batch_size,
                                                  shuffle=True)
                   for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    return image_datasets, dataloaders
```
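The Dataset contract that customData fulfills is just two methods: __len__ and __getitem__. The following minimal sketch (the ToyDataset class and its fake samples are assumptions for illustration, not part of example1.py) shows that any class implementing those two methods plugs straight into DataLoader:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Minimal Dataset: implementing __len__ and __getitem__ is enough."""
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        x = torch.full((3, 8, 8), float(idx))  # fake "image" filled with idx
        y = idx % 2                            # fake binary label
        return x, y

loader = DataLoader(ToyDataset(6), batch_size=2, shuffle=False)
xs, ys = next(iter(loader))
print(xs.shape, ys.tolist())  # torch.Size([2, 3, 8, 8]) [0, 1]
```

This is exactly how customData works at scale: __getitem__ reads and transforms one image, and DataLoader handles the batching, shuffling, and worker processes.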