Model函数
模型部分代码
Model简化
由于先前的模型写的太过于复杂和离谱,这里只提供简单的一个demo(回归初心),来看看pytorch训练模型的一个基本的过程,感觉这样就够了。
import torch
import torch.nn
import numpy as np
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import torch.optim as optimizer
注意:此处数据集在本地,因此download=False
;若需要下载的改为True
同样的,第一个参数为数据存放路径。
data_path = '../CIFAR_10_zhuanzhi/cifar10'
cifar = CIFAR10(data_path, train=True, download=False, transform=_task)
这里只是为了构造取样的角标,可根据自己的思路进行拓展。
此处使用了前百分之八十作为训练集,百分之八十到九十的作为验证集,后百分之十为测试集。
samples_count = len(cifar)
split_train = int(0.8 * samples_count)
split_valid = int(0.9 * samples_count)
index_list = list(range(samples_count))
train_idx, valid_idx, test_idx = index_list[:split_train], index_list[split_train:split_valid], index_list[split_valid:]
定义采样器:create training and validation, test sampler。
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
test_samlper = SubsetRandomSampler(test_idx )
# create iterator for train and valid, test dataset
trainloader = DataLoader(cifar, batch_size=256, sampler=train_sampler)
validloader = DataLoader(cifar, batch_size=256, sampler=valid_sampler)
testloader = DataLoader(cifar, batch_size=256, sampler=test_samlper )
一个简单的网络设计。
class Net(torch.nn.Module):
"""
网络设计了三个卷积层,一个池化层,一个全连接层
"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1)
self.conv2 = torch.nn.Conv2d(16, 32