基于Pytorch VGG16,从头搭建
代码实现VGG
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class VGG(nn.Module):
def __init__(self):
super(VGG, self).__init__()
self.conv1_1 = nn.Conv2d(3, 64, 3, 1, 1)
self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
self.pool1 = nn.MaxPool2d(2)
self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
self.pool2 = nn.MaxPool2d(2)
self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
self.pool3 = nn.MaxPool2d(2)
self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
self.pool4 = nn.MaxPool2d(2)
self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 1)
self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 1)
self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 1)
self.pool5 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(7 * 7 * 512, 4096)
self.fc2 = nn.Linear(4096, 4096)
self.classifier = nn.Linear(4096, 1000)
def forward(self, x):
x = F.relu(self.conv1_1(x))
x = F.relu(self.conv1_2(x))
x = self.pool1(x)
x = F.relu(self.conv2_1(x))
x = F.relu(self.conv2_2(x))
x = self.pool2(x)
x = F.relu(self.conv3_1(x))
x = F.relu(self.conv3_2(x))
x = F.relu(self.conv3_3(x))
x = self.pool3(x)
x = F.relu(self.conv4_1(x))
x = F.relu(self.conv4_2(x))
x = F.relu(self.conv4_3(x))
x = self.pool4(x)
x = F.relu(self.conv5_1(x))
x = F.relu(self.conv5_2(x))
x = F.relu(self.conv5_3(x))
x = self.pool5(x)
# print('conv5.shape',x.shape) #n*7*7*512
x = x.reshape(-1, 7 * 7 * 512)
# print('conv5.shape',x.shape) #n*7*7*512
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.classifier(x)
return x
x = torch.randn((1, 3, 224, 224))
vgg = VGG()
y = vgg(x)
print(y.shape)
torch.save(vgg, 'vgg.pth')
torch.onnx.export(vgg, x, 'vgg.onnx')
这段代码实现了一个简化版的 VGG 网络,并完成了以下主要任务:
定义 VGG 网络结构:
使用 nn.Module 定义了一个名为 VGG 的类。
在 __init__ 方法中初始化了所有卷积层、池化层和全连接层。
在 forward 方法中实现了前向传播过程,包括卷积、激活、池化和平铺操作,最后通过全连接层得到分类结果。
创建网络实例并进行前向传播:
创建一个输入张量 x,形状为 (1, 3, 224, 224)。
实例化 VGG 网络 vgg。
调用 vgg(x) 进行前向传播,得到输出张量 y。
打印输出张量 y 的形状。
保存模型和导出为 ONNX 格式:
使用 torch.save 将整个模型保存为 vgg.pth 文件。
使用 torch.onnx.export 将模型导出为 ONNX 格式的文件 vgg.onnx。
总结:
定义 VGG 网络:构建了一个包含多个卷积层、池化层和全连接层的网络。
前向传播:对输入数据进行前向传播,得到输出结果。
保存模型:将模型保存为 .pth 文件。
导出 ONNX:将模型导出为 ONNX 格式,便于在其他平台或框架中使用。