最近在练习 PyTorch 的使用。
首先下载猫狗数据:
链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
提取码:2xq4
然后写代码。感觉这种逐批次从硬盘读取数据的训练方式有点慢,但先跑起来再说。感兴趣的可以去看看 ResNet 源码,最好自己手敲一遍,练习效果更好。源码分析可参考博客园文章《resnet代码分析》(作者:慢行厚积)。先熟悉简单的 PyTorch 接口,后面再来做更高级的检测和分割。
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset
from torchvision import transforms,datasets,models
import shutil
# ---------------------------------------------------------------------------
# Dataset preparation: copy the raw cat/dog images into
# <base_dir>/{train,test}/{cats,dogs} using a shuffled 90/10 split.
# ---------------------------------------------------------------------------
random_state = 42
np.random.seed(random_state)  # fixed seed so the shuffle below is repeatable

original_dataset_dir = '/home/wangyunhao/dc/dogs-vs-cats/train/train'
# Half of the directory entries is used as the per-class image count.
# NOTE(review): this assumes the folder holds equally many cat.N.jpg and
# dog.N.jpg files with N running from 0 -- confirm against the dataset.
total_num = len(os.listdir(original_dataset_dir)) // 2
random_idx = np.array(range(total_num))
np.random.shuffle(random_idx)

base_dir = '/home/wangyunhao/dc/dog_cat_deal'
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: no race between
# the check and the creation, and missing parent directories are created too.
os.makedirs(base_dir, exist_ok=True)

sub_dirs = ['train', 'test']
animals = ['cats', 'dogs']
split_point = int(total_num * 0.9)  # first 90% -> train, remainder -> test
train_idx = random_idx[:split_point]
test_idx = random_idx[split_point:]
numbers = [train_idx, test_idx]  # index list per entry of sub_dirs

for idx, sub_dir in enumerate(sub_dirs):
    # Named split_dir (not `dir`) so the builtin dir() is not shadowed.
    split_dir = os.path.join(base_dir, sub_dir)
    os.makedirs(split_dir, exist_ok=True)
    for animal in animals:
        animal_dir = os.path.join(split_dir, animal)
        os.makedirs(animal_dir, exist_ok=True)
        # 'cats' -> 'cat.<i>.jpg', 'dogs' -> 'dog.<i>.jpg'
        fnames = [animal[:-1] + '.{}.jpg'.format(i) for i in numbers[idx]]
        for fname in fnames:
            src = os.path.join(original_dataset_dir, fname)
            dst = os.path.join(animal_dir, fname)
            shutil.copyfile(src, dst)
        print(animal_dir + ' total images : %d ' % (len(os.listdir(animal_dir))))
# ---------------------------------------------------------------------------
# Training configuration: RNG seeding and hyperparameters.
# ---------------------------------------------------------------------------
random_state = 1
# Seed every random number generator the script touches with the same value
# (PyTorch CPU, current CUDA device, all CUDA devices, NumPy), in that order.
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
):
    _seed_fn(random_state)

epochs = 10        # number of passes over the training set
batch_size = 10    # samples per mini-batch for both loaders
num_workers = 0    # DataLoader worker count; 0 loads in the main process
use_gpu = torch.cuda.is_available()
model_path = '/home/wangyunhao/dc/dc_dog.pt'
# ---------------------------------------------------------------------------
# Input pipeline: one shared preprocessing chain, two ImageFolder datasets
# (class labels come from the sub-folder names) and their loaders.
# ---------------------------------------------------------------------------
# Per-channel normalisation statistics.
# NOTE(review): presumably the usual ImageNet mean/std -- confirm.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize(256),        # resize so a 224 crop fits
    transforms.CenterCrop(224),    # fixed 224x224 centre crop
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
data_transform = transforms.Compose(_preprocess_steps)

# Both loaders use identical settings (including shuffle=True).
_loader_options = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

train_dataset = datasets.ImageFolder(
    root='/home/wangyunhao/dc/dog_cat_deal/train/',
    transform=data_transform,
)
train_loader = torch.utils.data.DataLoader(train_dataset, **_loader_options)

test_dataset = datasets.ImageFolder(
    root='/home/wangyunhao/dc/dog_cat_deal/test',
    transform=data_transform,
)
test_loader = torch.utils.data.DataLoader(test_dataset, **_loader_options)
# ---------------------------------------------------------------------------
# Model, loss and optimiser.
# ---------------------------------------------------------------------------
# Start from a freshly initialised two-class ResNet-101 ...
net = models.resnet101(num_classes=2)
# ... and resume from the checkpoint written by train() when one exists.
if os.path.exists(model_path):
    # The checkpoint is the whole pickled module (train() calls
    # torch.save(net, ...)). map_location='cpu' lets a checkpoint saved from a
    # GPU run load on a CPU-only machine; the model is moved to the GPU below.
    # NOTE: recent PyTorch releases default torch.load to weights_only=True,
    # which rejects full pickled modules -- pass weights_only=False there, and
    # only for checkpoints you trust (unpickling can execute arbitrary code).
    net = torch.load(model_path, map_location='cpu')
if use_gpu:
    net = net.cuda()
print(net)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
def train():
    """Train ``net`` for ``epochs`` epochs, evaluating and saving after each.

    Uses module-level state rather than parameters: ``net``, ``train_loader``,
    ``test_loader``, ``criterion``, ``optimizer``, ``epochs``, ``use_gpu`` and
    ``model_path``.

    For every epoch the function
      1. runs one optimisation pass over ``train_loader`` and prints the mean
         per-sample loss and the accuracy (in percent),
      2. evaluates on ``test_loader`` with gradients disabled and prints the
         same two metrics,
      3. saves the whole model object to ``model_path``.

    Returns:
        None.
    """
    for epoch in range(epochs):
        # The evaluation pass at the end of each epoch puts the network in
        # eval mode; switch back explicitly so every epoch (not just the
        # first) trains with training-mode layer behaviour.
        net.train()
        running_loss = 0.0
        train_correct = 0
        train_total = 0
        for i, data in enumerate(train_loader, 0):
            inputs, train_labels = data
            print(i, train_labels)  # per-batch progress output
            if use_gpu:
                inputs, labels = inputs.cuda(), train_labels.cuda()
            else:
                labels = train_labels

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_count = labels.size(0)
            _, train_predicted = torch.max(outputs.detach(), 1)
            # .item() keeps the counters plain Python numbers.
            train_correct += (train_predicted == labels).sum().item()
            # criterion averages over the batch, so weight by the batch size
            # to make running_loss / train_total a true per-sample mean.
            running_loss += loss.item() * batch_count
            train_total += batch_count

        seen = max(train_total, 1)  # avoid dividing by zero on an empty loader
        print('train %d epoch loss: %.3f acc: %.3f ' % (
            epoch + 1, running_loss / seen, 100.0 * train_correct / seen))

        correct = 0
        test_loss = 0.0
        test_total = 0
        net.eval()
        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for images, labels in test_loader:
                if use_gpu:
                    images, labels = images.cuda(), labels.cuda()
                outputs = net(images)
                loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs, 1)

                batch_count = labels.size(0)
                test_loss += loss.item() * batch_count
                test_total += batch_count
                correct += (predicted == labels).sum().item()

        seen = max(test_total, 1)
        print('test %d epoch loss: %.3f acc: %.3f' % (
            epoch + 1, test_loss / seen, 100.0 * correct / seen))

        # Save after every epoch so an interrupted run can be resumed.
        torch.save(net, model_path)
# Script entry point: run the full training loop when this file is executed.
train()