# train.py
import torch
import torch.nn as nn
import numpy as np
from torchvision import models
import torchvision.transforms as transforms
from cifar10 import CIFAR10
def genNet(num_classes=10, pretrained=True):
    """Build a ResNet-34 adapted to small (e.g. 32x32 CIFAR-10) images.

    Starts from torchvision's ResNet-34 and replaces the ImageNet-style
    stem (7x7 stride-2 conv followed by a stride-2 max-pool) with a 3x3
    stride-1 conv and a no-op pool, so small images are not down-sampled
    before the residual stages.

    Args:
        num_classes: Number of logits produced by the final linear layer.
            Defaults to 10 (CIFAR-10).
        pretrained: Forwarded to ``models.resnet34``. When True, ImageNet
            weights are loaded for the layers that are kept; ``conv1``,
            ``bn1`` and ``fc`` are always replaced by freshly initialized
            layers regardless of this flag.

    Returns:
        The modified torchvision ResNet module.
    """
    net = models.resnet34(pretrained=pretrained)
    # 3x3 stride-1 stem keeps full spatial resolution for small inputs.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.bn1 = nn.BatchNorm2d(64)
    # kernel_size=1, stride=1 makes this pool a deliberate no-op, removing
    # the stock stride-2 down-sampling.
    net.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)
    # Adaptive pooling reduces any spatial size to 1x1. For 32x32 inputs the
    # last stage is 4x4, so this matches the former AvgPool2d(4, stride=1)
    # exactly, but it no longer breaks the fc layer for other input sizes.
    net.avgpool = nn.AdaptiveAvgPool2d(1)
    # Reuse the backbone's feature width (512 for ResNet-34) instead of
    # hard-coding 512 * expansion.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net
# ---- Training configuration -------------------------------------------
# Use the GPU when one is present; set to False to force CPU training.
gpu_en = torch.cuda.is_available()
if gpu_en:
    # Only touch the CUDA runtime when GPU training is enabled; calling
    # set_device unconditionally fails on machines without CUDA.
    torch.cuda.set_device(0)
load = True    # resume weights from net.pkl before training
save = True
num_epochs = 1
nbatch = 150   # mini-batch size passed to the DataLoader
lr = 0.001

# ---- Data pipeline ----------------------------------------------------
# Random crop with padding plus horizontal flip for augmentation, then
# scale each channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = CIFAR10('D:\\dataset\\cifar10', train=True, transform=transform)

# ---- Model ------------------------------------------------------------
net = genNet()
if load:
    # map_location='cpu' lets a checkpoint saved from a GPU model be read on
    # any machine; load_state_dict copies the tensors into net, which is
    # moved to the GPU below when enabled.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    net.load_state_dict(torch.load('net.pkl', map_location='cpu'))
if gpu_en:
    net = net.cuda()

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=nbatch,
    shuffle=True,
    num_workers=0,
    pin_memory=gpu_en,  # page-locked host memory speeds host-to-GPU copies
)
criterion = nn.CrossEntropyLoss()
optimizer=torch.o