使用torchvision.datasets.ImageFolder来读取图片
torchvision.datasets.ImageFolder(root="root folder path", [transform, target_transform])
- root : 指定图片存储的路径,在下面的例子中是’./data/dogcat_2’
- transform: 一个函数,原始图片作为输入,返回一个转换后的图片。
- target_transform - 一个函数,输入为target,输出对其的转换。例子,输入的是图片标注的string,输出为word的索引。
有几个变量 - self.classes - 用一个list保存 类名
- self.class_to_idx - 类名对应的 索引
- self.imgs - 保存(img-path, class) tuple的list
定义一个torch.utils.data.Dataset数据类
Dataset有两个函数
- _getitem_
- _len_
class Dataset_1(torch.utils.data.Dataset):
def __init__(self,root,is_resize=False,is_transfrom=False):
self.root=root
self.is_resize=is_resize
self.is_transfrom=is_transfrom
self.imgs_list=...#保存图片路径节省内存
self.labs_list=...
def __getitem__(self, index):
img_path,lab=self.imgs_list[index],self.labs_list[index]
img_data = Image.open(img_path)
if self.is_transfrom:
img_data=self.is_transfrom(img_data)
return img_data,lab
def __len__(self):
return len(self.imgs_list)
定义好Dataset数据类,之后使用DataLoader导入
torch.utils.data.DataLoader(dataset=Dataset_1, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
有时需要对数据进行处理
train_transforms = torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor()
])
img = Image.open('test.png')
train_transforms(img)
训练数据
train_loader = torch.utils.data.DataLoader(dataset=Dataset_1, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
model = torchvision.models.__dict__['resnet101'](pretrained=True)
model.load_state_dict(torch.load('...pth'))
model.to(DEVICE)
# 训练
optimizer = torch.optim.SGD(model.parameters(),lr = 0.01)
loss_fn = torch.nn.MSELoss()
model.train()
for batch_index,(data,target) in enumerate(train_loader):
data, target = data.to(DEVICE), target.to(DEVICE)
output = model(data)
loss = loss_fn(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 推理
model.eval()
output = model(data)