# Quick reference for checking the installed library versions:
#   import torchvision; cv_version = torchvision.__version__
#   import torch; version = torch.__version__; version_m = version.split('.')
# PyTorch 0.4.0 changed a number of things; the details are covered in:
#   "How do I port code written for PyTorch 0.3.0 to 0.4.0?"
#   https://www.zhihu.com/question/287154766?sort=created
#   "Common PyTorch problems; PyTorch 0.4 upgrade guide: no_grad, cuda(async=True)"
#   https://blog.csdn.net/qq_18644873/article/details/88216904
#   "PyTorch 0.4 upgrade guide: no_grad"
#   https://blog.csdn.net/jacke121/article/details/80597759
# -*-coding:utf-8-*-
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from MyDataset import MyDataset
from shufflenetv2_inference import *
from efficientnet_inference import *
from mobilenetv3_inference import *
#os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
import os
# Restrict the process to the first GPU. This must run before CUDA is
# initialised for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Imported unconditionally so that any code which still wraps tensors in
# `Variable` keeps working on every PyTorch version (from 0.4 onwards the
# wrapper simply hands back a tensor).
from torch.autograd import Variable

version = torch.__version__
version_m = version.split('.')

if int(version_m[0]) == 0 and int(version_m[1]) < 4:
    # PyTorch < 0.4 has no `_rebuild_tensor_v2`, which the unpickler looks up
    # when a checkpoint references it; install a shim built on the old
    # `_rebuild_tensor` helper if it is missing.
    import torch._utils
    try:
        torch._utils._rebuild_tensor_v2
    except AttributeError:
        def _rebuild_tensor_v2(storage, storage_offset, size, stride,
                               requires_grad, backward_hooks):
            """Rebuild a tensor from storage and restore its autograd flags."""
            tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
            tensor.requires_grad = requires_grad
            tensor._backward_hooks = backward_hooks
            return tensor
        torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

# Report the PyTorch version in use.
print(version)
def train(args, model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one epoch and periodically print the batch loss.

    Args:
        args: Parsed command-line namespace; only ``log_interval`` is read.
        model: Network to optimise. It is switched to training mode here.
        device: Where each batch is moved before the forward pass. Anything
            accepted by ``Tensor.to`` works: a CUDA ordinal such as ``0``,
            a string such as ``"cpu"``, or a ``torch.device``.
        train_loader: Loader yielding ``(data, target)`` batches. It must
            support ``len()`` and expose a sized ``dataset`` attribute.
        optimizer: Optimiser bound to ``model``'s parameters.
        epoch: Current epoch number; used only in the log line.
    """
    model.train()
    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)

    for batch_idx, (data, target) in enumerate(train_loader):
        # Tensors carry autograd state themselves since PyTorch 0.4, so no
        # `Variable` wrapper is needed; `.to` also works without a GPU.
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        # The model emits raw scores; nll_loss expects log-probabilities.
        log_probs = F.log_softmax(output, dim=1)
        loss = F.nll_loss(log_probs, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            # `.item()` extracts the Python float from the 0-dim loss tensor;
            # indexing it with `[0]` raises on current PyTorch releases.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), num_samples,
                100. * batch_idx / num_batches, loss.item()))
def test(args, model, device, test_loader, epoch):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Args:
        args: Parsed command-line namespace (not read; kept for call symmetry
            with ``train``).
        model: Network to evaluate. It is switched to evaluation mode here.
        device: Where each batch is moved before the forward pass. Anything
            accepted by ``Tensor.to`` works: a CUDA ordinal such as ``0``,
            a string such as ``"cpu"``, or a ``torch.device``.
        test_loader: Loader yielding ``(data, target)`` batches and exposing
            a sized ``dataset`` attribute.
        epoch: Current epoch number (not read).
    """
    model.eval()
    test_loss = 0.0
    correct = 0

    # Assigning `torch.volatile` has no effect on autograd; `no_grad` is what
    # actually disables gradient tracking during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # nll_loss expects log-probabilities, not the raw model scores.
            log_probs = F.log_softmax(output, dim=1)
            # Sum per-sample losses so the final division by the dataset size
            # yields a true per-sample average even with a ragged last batch.
            test_loss += F.nll_loss(log_probs, target, reduction='sum').item()
            pred = log_probs.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))


# Keep pytest from collecting this evaluation helper as a test case when the
# module's names are imported into a test file.
test.__test__ = False
def main():
    """Parse the command line, build data loaders and the model, then train.

    Each epoch runs one training pass followed by one evaluation pass. When
    ``args.save_model`` is set, the model's ``state_dict`` is written to
    ``cnpk.pt`` after every epoch.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Example')
    parser.add_argument('--batch-size', type=int, default=20, metavar='N',
                        help='input batch size for training (default: 20)')
    parser.add_argument('--test-batch-size', type=int, default=20, metavar='N',
                        help='input batch size for testing (default: 20)')
    parser.add_argument('--epochs', type=int, default=300000, metavar='N',
                        help='number of epochs to train (default: 300000)')
    parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                        help='number of output classes (default: 1000)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    # NOTE: with store_true and default=True this flag is always on.
    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)

    # Honour --no-cuda (and machines without a GPU) instead of always using
    # CUDA device 0.
    device = torch.device("cuda:0" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # ---------------------------- step 1/5: load the data ----------------------------
    # NOTE(review): training and validation read the same list file — confirm
    # this is intended.
    train_txt_path = './mnist_valid.txt'
    valid_txt_path = './mnist_valid.txt'

    # Preprocessing settings
    trainTransform = transforms.Compose([
        transforms.Resize((224), interpolation=2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.486, 0.458, 0.408],
                             std=[0.229, 0.224, 0.225])
    ])
    validTransform = transforms.Compose([
        transforms.Resize((224), interpolation=2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.486, 0.458, 0.408],
                             std=[0.229, 0.224, 0.225])
    ])

    # Build the MyDataset instances; img_path is a prefix that can be
    # prepended to the image paths listed in the txt file.
    train_data = MyDataset(img_path='', txt_path=train_txt_path, transform=trainTransform)
    valid_data = MyDataset(img_path='', txt_path=valid_txt_path, transform=validTransform)

    # Build the DataLoaders. The loader options computed above are passed
    # through, and validation uses its own --test-batch-size setting.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_data, batch_size=args.batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_data, batch_size=args.test_batch_size, **kwargs)

    # Available model constructors:
    #   efficientnet_b0, efficientnet_b1, efficientnet_b2b, efficientnet_b3b,
    #   mobilenetv3_large_w1, mobilenetv3_small_w1, shufflenetv2_w1
    model = efficientnet_b0(num_classes=args.num_classes, pretrained=True).to(device)
    print("-------------------------------------------")
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, valid_loader, epoch)
        # Checkpoint every epoch so that a long run can be interrupted
        # without losing the trained weights.
        if args.save_model:
            torch.save(model.state_dict(), "cnpk.pt")


if __name__ == '__main__':
    main()