Method 1
Data organization. ImageFolder infers class labels from subfolder names, so each class needs its own subfolder under train and val (class_a, class_b below are placeholders):
dataset_name
----train
--------class_a
--------class_b
----val
--------class_a
--------class_b
import os
import torch
from torchvision import datasets, models, transforms

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
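Once the datasets are built, a quick inspection confirms what ImageFolder discovered. A minimal sketch; the printed values assume the standard hymenoptera_data download:

# inspect the discovered classes and one batch
print(class_names)    # e.g. ['ants', 'bees']
print(dataset_sizes)  # number of images per split
inputs, labels = next(iter(dataloaders['train']))
print(inputs.shape)   # torch.Size([4, 3, 224, 224])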
The training loop below alternates a training and a validation phase each epoch; it assumes model, criterion, optimizer, scheduler, device and num_epochs are already defined:

import copy

best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and a validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # Iterate over data.
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward; track history only in train
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        if phase == 'train':
            # step the LR scheduler once per epoch, after the optimizer updates
            scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # deep copy the model weights when validation accuracy improves
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    print()
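After the loop finishes, the best validation weights can be restored and saved; the file name here is arbitrary:

# restore the best weights found during training and persist them
model.load_state_dict(best_model_wts)
torch.save(model.state_dict(), 'best_model.pth')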
Method 2
Custom root path + image paths written in a txt file
Contents of the txt file: each line is an image path followed by its label index.
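For example (file names invented for illustration; the label indices follow the folder-to-label mapping used in the generator script below):

010101/img_0001.jpg 0
010202/img_0042.jpg 7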
Code to generate the txt files:
# -*- coding: utf-8 -*-
"""
@Project: googlenet_classification
@File   : create_labels_files.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date   : 2018-08-11 10:15:28
"""
import os


def write_txt(content, filename, mode='w'):
    """Save data to a txt file.
    :param content: data to save, type -> list
    :param filename: output file name
    :param mode: write mode, 'w' or 'a'
    :return: void
    """
    with open(filename, mode) as f:
        for line in content:
            str_line = ""
            for col, data in enumerate(line):
                if not col == len(line) - 1:
                    # separate columns with a space
                    str_line = str_line + str(data) + " "
                else:
                    # end each line with a newline "\n"
                    str_line = str_line + str(data) + "\n"
            f.write(str_line)


# map class-folder names to label indices
label_map = {
    '010101': 0,
    '010102': 1,
    '010103': 2,
    '010105': 3,
    '010106': 4,
    '010107': 5,
    '010201': 6,
    '010202': 7,
    '030000': 8,
}


def get_files_list(dir):
    """Walk dir recursively and collect all files (including those in subfolders).
    :param dir: root directory to traverse
    :return: list of [relative_path, label] pairs
    """
    # parent: parent directory, dirnames: subfolder names, filenames: file names
    files_list = []
    for parent, dirnames, filenames in os.walk(dir):
        for filename in filenames:
            curr_file = parent.split(os.sep)[-1]
            if curr_file not in label_map:
                continue  # skip files whose folder has no label assigned
            files_list.append([os.path.join(curr_file, filename), label_map[curr_file]])
    return files_list


if __name__ == '__main__':
    train_dir = r'F:\WU_work\guandao\data\guandao20190904_10\train'
    train_txt = r'F:\WU_work\guandao\data\guandao20190904_10\train.txt'
    train_data = get_files_list(train_dir)
    write_txt(train_data, train_txt, mode='w')

    val_dir = r'F:\WU_work\guandao\data\guandao20190904_10\validation'
    val_txt = r'F:\WU_work\guandao\data\guandao20190904_10\val.txt'
    val_data = get_files_list(val_dir)
    write_txt(val_data, val_txt, mode='w')
Build the MyDataset instances. The img_path argument is a mechanism for prepending a root directory to the image paths stored in the txt file:

# img_path is the root of the training or validation set,
# e.g. F:\WU_work\guandao\data\guandao20190904_10\train;
# with img_path='' the paths in the txt file are used as-is
train_data = MyDataset(img_path='', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path='', txt_path=valid_txt_path, transform=validTransform)
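The post does not show MyDataset itself. A minimal sketch that matches the constructor arguments above and the txt format produced by create_labels_files.py (one "path label" pair per line), not the author's original implementation:

import os
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # a minimal sketch, assuming the "path label" txt format above
    def __init__(self, img_path, txt_path, transform=None):
        self.img_path = img_path
        self.transform = transform
        self.samples = []
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.rstrip('\n')
                if not line:
                    continue
                path, label = line.rsplit(' ', 1)  # the last field is the label
                self.samples.append((path, int(label)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        img = Image.open(os.path.join(self.img_path, path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label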
Data loading
import torch
from torchvision import transforms

# -------------------- step 1/5: load the data --------------------
train_txt_path = './Data/train.txt'
valid_txt_path = './Data/valid.txt'

# preprocessing settings
normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    normTransform
])
# note: images fed through validTransform must already share one size,
# otherwise batching fails; add a Resize here if they do not
validTransform = transforms.Compose([
    transforms.ToTensor(),
    normTransform
])

# build MyDataset instances; img_path is prepended to the paths stored in the txt files
train_data = MyDataset(img_path='', txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(img_path='', txt_path=valid_txt_path, transform=validTransform)

# build the DataLoaders
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=16, shuffle=True)
valid_loader = torch.utils.data.DataLoader(dataset=valid_data, batch_size=16)
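The normMean/normStd values above are specific to one dataset. As a sketch (not part of the original post), per-channel statistics for a new dataset can be estimated like this, assuming the images are loaded without Normalize:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# estimate per-channel mean/std over the training set (tensors in [0, 1])
plain_data = MyDataset(img_path='', txt_path=train_txt_path,
                       transform=transforms.Compose([transforms.Resize((224, 224)),
                                                     transforms.ToTensor()]))
loader = DataLoader(plain_data, batch_size=16)
n_pixels = 0
channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
for imgs, _ in loader:
    n_pixels += imgs.numel() // 3                    # pixels per channel in this batch
    channel_sum += imgs.sum(dim=[0, 2, 3])           # per-channel sum
    channel_sq_sum += (imgs ** 2).sum(dim=[0, 2, 3])
mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
print(mean, std)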
train_loader is an iterable: each iteration yields one batch of images together with their corresponding labels.
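A quick sanity check of one batch (shapes assume the 224x224 transforms above):

for imgs, labels in train_loader:
    print(imgs.shape)   # torch.Size([16, 3, 224, 224])
    print(labels[:4])   # the first few integer class labels
    break               # inspect only the first batch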