pytorch基于cifar10实现NIN网络
模型Network in network
train.py
#导出需要的库
import torch
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import data
import cPickle as pickle
import numpy
from torch.autograd import Variable
#给定数据集路径,加载数据
trainset = data.dataset(root='./data', train=True)
#该接口会将dataset根据batchsize大小,是否shuffle等封装成一个batchsize大小的tensor
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
#测试数据读取
testset = data.dataset(root='./data', train=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
shuffle=False, num_workers=2)
#数据集总类别数
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#定义模型
class Net(nn.Module):#继承父类nn.Module
def __init__(self):
super(Net, self).__init__()#super可以指代父类而不需要显式的声明,这对更改基类(此处为__init__())的时候是有帮助的,使得代码更容易维护
self.classifier = nn.Sequential(
nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
nn.ReLU(inplace=True),
nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(160, 96, kernel_size=1, stride=1, padding=0),
nn.ReLU(inplace=True),