PyTorch:继承 Dataset 类来加载自定义数据
首先介绍自定义的 Mydataset 类(保存为 mydataset.py):
import os
import glob
import csv
import random
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
class Mydataset(Dataset):
    """Image-classification dataset built from one sub-directory per class.

    Every directory directly under ``root`` is treated as a class; classes are
    numbered 0, 1, 2, ... in sorted name order.  The shuffled list of
    ``(image path, label)`` pairs is cached in ``<root>/imagess.csv`` so that
    the 'train' / 'val' / 'test' instances all slice the same ordering.

    Args:
        root: Directory that contains the per-class image folders.
        resize: Side length, in pixels, of the square image returned by
            ``__getitem__``.
        mode: ``'train'`` selects the first 60% of the cached list, ``'val'``
            the next 20%, and any other value the final 20%.
    """

    # Channel statistics shared by ``__getitem__`` (Normalize) and
    # ``denormalize`` so the two can never drift apart.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        super(Mydataset, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Map class-folder name -> integer label (0, 1, 2, ...).
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)
        print(self.name2label)

        self.images, self.labels = self.load_csv('imagess.csv')

        # 60% / 20% / 20% split over the cached (already shuffled) ordering.
        total = len(self.images)
        train_end, val_end = int(0.6 * total), int(0.8 * total)
        if mode == 'train':
            chosen = slice(None, train_end)
        elif mode == 'val':
            chosen = slice(train_end, val_end)
        else:
            chosen = slice(val_end, None)
        self.images = self.images[chosen]
        self.labels = self.labels[chosen]

    def load_csv(self, filename):
        """Return ``(images, labels)`` read from ``<root>/<filename>``.

        When the CSV file does not exist yet it is created first: all
        ``*.png`` / ``*.jpg`` / ``*.jpeg`` files inside the class folders are
        collected, shuffled once, and written as ``path,label`` rows.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            print(len(images), images)
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent folder name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print('writen into csv file:', filename)

        images, labels = [], []
        # newline='' is what the csv module requires for correct handling of
        # quoted fields and line endings.
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image tensor, label tensor)``.

        The image is converted to RGB, resized to 1.25x ``resize``, centre
        cropped to ``resize`` x ``resize``, converted to a tensor and
        normalised with ``_MEAN`` / ``_STD``.  The random rotation is applied
        to the 'train' split only, so validation and test results are
        reproducible from one run to the next.
        """
        path, label = self.images[idx], self.labels[idx]

        enlarged = int(self.resize * 1.25)
        steps = [transforms.Resize((enlarged, enlarged))]
        if self.mode == 'train':
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ]

        # Close the file handle deterministically instead of leaving it to
        # the garbage collector; convert() returns an independent image.
        with Image.open(path) as handle:
            img = handle.convert('RGB')
        img = transforms.Compose(steps)(img)
        return img, torch.tensor(label)

    def denormalize(self, x_hat):
        """Undo the normalisation applied in ``__getitem__``.

        Computes ``x = x_hat * std + mean`` with the per-channel constants
        reshaped to ``[3, 1, 1]`` so they broadcast over ``[3, h, w]`` images
        as well as ``[b, 3, h, w]`` batches.  The constants are created on the
        device of ``x_hat`` so the method also works for CUDA tensors.
        """
        dtype = x_hat.dtype if x_hat.is_floating_point() else torch.float32
        mean = torch.tensor(self._MEAN, dtype=dtype, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self._STD, dtype=dtype, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean
def main():
    """Preview the images under ``dataset`` in visdom, one batch at a time.

    The folder is read with torchvision's ``ImageFolder``; every image is
    resized to 64x64 and converted to a tensor.  Each shuffled batch of 32 is
    sent to the 'batch' window and its labels to the 'label' window, followed
    by a 10 second pause before the next batch.
    """
    import time
    import torchvision
    import visdom

    viz = visdom.Visdom()

    preprocess = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
    folder_db = torchvision.datasets.ImageFolder(root='dataset', transform=preprocess)

    for images, targets in DataLoader(folder_db, batch_size=32, shuffle=True):
        viz.images(images, nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(targets.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)


if __name__ == "__main__":
    main()
基于 resnet18 加载数据进行训练。
首先完成 Flatten.py,其中定义了 Flatten 模块和 plot_image 辅助函数:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    An input of shape ``[b, d1, d2, ...]`` becomes ``[b, d1 * d2 * ...]``;
    a 1-D input of shape ``[b]`` becomes ``[b, 1]``.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Multiply the trailing sizes with plain integers: this avoids
        # allocating a tensor on every call and keeps the result an int even
        # when there are no trailing dimensions.
        features = 1
        for size in x.shape[1:]:
            features *= size
        # reshape() also accepts non-contiguous inputs (view() raises), and
        # giving the batch size explicitly keeps empty batches unambiguous.
        return x.reshape(x.shape[0], features)
def plot_image(img, label, name):
    """Show up to six images from a batch in a 2x3 grid of grey-scale plots.

    Args:
        img: Indexable batch of images; channel 0 of each image is drawn.
        label: Indexable batch of labels whose items provide ``.item()``.
        name: Text placed before each label in the subplot title.

    Batches with fewer than six entries are drawn as far as they go instead of
    raising ``IndexError``.
    """
    plt.figure()
    shown = min(6, len(img), len(label))
    for i in range(shown):
        plt.subplot(2, 3, i + 1)
        # NOTE(review): 0.3081 / 0.1307 presumably undo an
        # (x - 0.1307) / 0.3081 normalisation applied upstream - confirm they
        # match the data that is actually passed in.
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title('{}: {}'.format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay the grid out once, after every subplot and title exists.
    plt.tight_layout()
    plt.show()
完成 train_resnet18.py 训练程序:
import torch
import visdom
import torch.nn as nn
import torch.optim
from mydataset import Mydataset
from torch.utils.data import Dataset, DataLoader
from Flatten import Flatten
from torchvision.models.resnet import resnet18
# Hyper-parameters of the fine-tuning run.
batchsize = 32
learning_rate = 1e-5
epoches = 1

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The three splits of the same image folder, each resized to 32x32.
train_db, val_db, test_db = [
    Mydataset('datasets', 32, mode=split) for split in ('train', 'val', 'test')
]

# Only the training loader shuffles its samples.
train_loader = DataLoader(train_db, batch_size=batchsize, shuffle=True, num_workers=4)
val_loader, test_loader = [
    DataLoader(split_db, batch_size=batchsize, num_workers=2)
    for split_db in (val_db, test_db)
]

# visdom client used for the training-time plots.
viz = visdom.Visdom()
def evaluate(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Args:
        model: Module mapping a batch ``x`` to per-class scores; the predicted
            class is the ``argmax`` over dimension 1.
        loader: Iterable of ``(x, y)`` batches that exposes ``.dataset``.

    Returns:
        ``correct / len(loader.dataset)`` as a float, or ``0.0`` when the
        dataset is empty.

    The model is switched to evaluation mode while scoring, so layers such as
    BatchNorm and Dropout behave deterministically and running statistics are
    not updated; the previous train/eval mode is restored afterwards.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    # Move each batch to wherever the model's weights live; a model without
    # parameters leaves the batches where they are.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                if target_device is not None:
                    x, y = x.to(target_device), y.to(target_device)
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
    finally:
        model.train(was_training)
    return correct / total
def main():
    """Fine-tune a pretrained ResNet-18 on ``train_db`` and report test accuracy.

    The final fully connected layer of ResNet-18 is replaced by ``Flatten``
    plus a new linear layer with one output per class found in ``train_db``.
    After every epoch the model is scored on the validation split; whenever
    the score improves the weights are written to the checkpoint file.  Once
    training ends, the best checkpoint is loaded back before scoring the test
    split, so the reported test accuracy belongs to the best validation epoch.
    Loss and validation accuracy are plotted through visdom.
    """
    checkpoint_path = 'resnet18-circle25-50.pkl'
    num_classes = len(train_db.name2label)

    backbone = resnet18(pretrained=True)
    model = nn.Sequential(
        *list(backbone.children())[:-1],   # everything up to the pooled [b, 512, 1, 1] features
        Flatten(),                         # [b, 512, 1, 1] -> [b, 512]
        nn.Linear(512, num_classes),       # new classification head
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    best_acc, best_epoch = 0, 0
    checkpoint_saved = False
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epoches):
        for x, y in train_loader:
            # Show the current batch (de-normalised) and its labels.
            viz.images(train_db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
            viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

            x, y = x.to(device), y.to(device)
            model.train()
            logits = model(x)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate after every epoch and keep the best weights on disk.
        val_acc = evaluate(model, val_loader)
        viz.line([val_acc], [global_step], win='val_acc', update='append')
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

    print("best acc:", best_acc, "best epoch:", best_epoch)

    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print("loaded from ckpt!")
    else:
        # Validation never improved (or no epoch ran): keep the current
        # weights so the checkpoint file still exists after the run.
        torch.save(model.state_dict(), checkpoint_path)

    test_acc = evaluate(model, test_loader)
    print("test acc:", test_acc)


if __name__ == "__main__":
    main()
最后使用 visdom 对训练过程进行可视化,完成物体的识别。