搭建神经网络
CIFAR10 Model
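推算公式（Conv2d 输出尺寸，见 PyTorch 文档；stride=1、dilation=1、kernel_size=5 时，要让 32×32 输入保持 32×32 输出，需 padding=2）：
$$H_{out} = \left\lfloor \frac{H_{in} + 2\cdot\mathrm{padding} - \mathrm{dilation}\cdot(\mathrm{kernel\_size}-1) - 1}{\mathrm{stride}} \right\rfloor + 1$$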
损失函数
- 计算实际输出与目标（真实标签）之间的差距
- 为更新模型参数提供依据（通过反向传播计算梯度，优化器再根据梯度更新参数）
代码
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear, CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR10 split selected by train=False (the test split); ToTensor turns each
# image into a float (C, H, W) tensor. Files must already exist under
# ../dataset because download is disabled.
dataset = torchvision.datasets.CIFAR10(
    root="../dataset",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=False,
)
# Mini-batches of 64 samples, iterated in dataset order (no shuffle argument).
dataloader = DataLoader(dataset=dataset, batch_size=64)
class MyCNN(nn.Module):
    """Small CNN that maps 3x32x32 images to 10 class scores.

    Three stages of (5x5 conv, padding 2) followed by a 2x2 max-pool shrink the
    spatial size 32 -> 16 -> 8 -> 4, so the flattened feature vector has
    64 * 4 * 4 = 1024 entries -- hence ``Linear(1024, 64)``.

    NOTE(review): there is no activation (e.g. ReLU) between layers apart from
    max-pooling; confirm this is intentional.
    """

    def __init__(self):
        super().__init__()
        layers = []
        # Build conv/pool pairs in the same order as a hand-written list so the
        # Sequential indices (and thus state_dict keys) are conv.0 ... conv.8.
        for in_channels, out_channels in ((3, 32), (32, 32), (32, 64)):
            layers.append(Conv2d(in_channels, out_channels, kernel_size=5, padding=2))
            layers.append(MaxPool2d(2))
        layers.extend([Flatten(), Linear(1024, 64), Linear(64, 10)])
        # The attribute name "conv" is part of every state_dict key; keep it so
        # previously saved checkpoints still load.
        self.conv = Sequential(*layers)

    def forward(self, input):
        """Return (batch, 10) scores for a (batch, 3, 32, 32) input tensor."""
        return self.conv(input)
mycnn = MyCNN()
# Loss function & optimizer
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(mycnn.parameters(), lr=0.01)
for epoch in range(20):
    # Sum of the per-batch (mean-reduced) losses for this epoch, as a Python float.
    running_loss = 0.0
    for imgs, target in dataloader:
        output = mycnn(imgs)
        result_loss = loss(output, target)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        # .item() extracts a plain number. Adding the loss tensor itself would
        # accumulate autograd history across the whole epoch and print a
        # grad-tracking tensor instead of a value.
        running_loss += result_loss.item()
    print(running_loss)
# # Sanity-check the model's output shape
# input = torch.ones((64, 3, 32, 32))
# output = mycnn(input)
# print(output.shape)
# # Computation graph (TensorBoard)
# writer = SummaryWriter("torch_seq")
# writer.add_graph(mycnn, input)
# writer.close()
计算图
模型的加载与保存
以下代码分别演示两种保存与加载方式（保存方式与加载方式需要一一对应）
import os
import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torchvision.models import VGG16_Weights
# Redirect torch.hub's download cache (pretrained weights) to ./torch_cache.
os.environ['TORCH_HOME'] = './torch_cache'
# VGG16 with ImageNet-pretrained weights.
vgg16_true = torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
train_data = torchvision.datasets.CIFAR10(root="../dataset", train=True, transform=torchvision.transforms.ToTensor(), download=True)
# Transfer learning: append a 1000 -> 10 layer after the classifier so the
# ImageNet outputs are mapped to the 10 CIFAR10 classes.
vgg16_true.classifier.add_module('add_Linear', Linear(1000, 10))
# VGG16 architecture with randomly initialised (NOT pretrained) weights,
# matching the "_false" in its name.
vgg16_false = torchvision.models.vgg16(weights=None)
# print(vgg16_false)
# print(vgg16_true)

# Save method 1: pickle the whole module (architecture + parameters).
torch.save(vgg16_false, "vgg_method1.pth")
# Save method 2: parameters only, as a state_dict (officially recommended).
torch.save(vgg16_false.state_dict(), "vgg_method2.pth")

# Load method 1: unpickle the full module object. PyTorch >= 2.6 defaults to
# weights_only=True, which rejects whole-module pickles, so opt out explicitly.
# Only do this for files you created or trust: unpickling can execute code.
model = torch.load("vgg_method1.pth", weights_only=False)
# Load method 2: rebuild the same architecture, then copy the saved parameters
# in. No pretrained weights are needed because they would be overwritten.
model2 = torchvision.models.vgg16(weights=None)
# load_state_dict (not state_dict) is what copies the tensors into the model.
model2.load_state_dict(torch.load("vgg_method2.pth"))