从零搭建VGGNet

基于Pytorch VGG16,从头搭建

代码实现VGG

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.pool1 = nn.MaxPool2d(2)

        self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.pool2 = nn.MaxPool2d(2)

        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
        self.pool3 = nn.MaxPool2d(2)

        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.pool4 = nn.MaxPool2d(2)

        self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.pool5 = nn.MaxPool2d(2)

        self.fc1 = nn.Linear(7 * 7 * 512, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.classifier = nn.Linear(4096, 1000)

    def forward(self, x):
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        x = self.pool1(x)

        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_2(x))
        x = F.relu(self.conv3_3(x))
        x = self.pool3(x)

        x = F.relu(self.conv4_1(x))
        x = F.relu(self.conv4_2(x))
        x = F.relu(self.conv4_3(x))
        x = self.pool4(x)

        x = F.relu(self.conv5_1(x))
        x = F.relu(self.conv5_2(x))
        x = F.relu(self.conv5_3(x))
        x = self.pool5(x)

        # print('conv5.shape',x.shape) #n*7*7*512
        x = x.reshape(-1, 7 * 7 * 512)
        # print('conv5.shape',x.shape) #n*7*7*512

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.classifier(x)

        return x


x = torch.randn((1, 3, 224, 224))
vgg = VGG()
y = vgg(x)
print(y.shape)

torch.save(vgg, 'vgg.pth')
torch.onnx.export(vgg, x, 'vgg.onnx')

这段代码实现了一个简化版的 VGG 网络,并完成了以下主要任务:
定义 VGG 网络结构:
使用 nn.Module 定义了一个名为 VGG 的类。
在 __init__ 方法中初始化了所有卷积层、池化层和全连接层。
在 forward 方法中实现了前向传播过程,包括卷积、激活、池化和平铺操作,最后通过全连接层得到分类结果。
创建网络实例并进行前向传播:
创建一个输入张量 x,形状为 (1, 3, 224, 224)。
实例化 VGG 网络 vgg。
调用 vgg(x) 进行前向传播,得到输出张量 y。
打印输出张量 y 的形状。
保存模型和导出为 ONNX 格式:
使用 torch.save 将整个模型保存为 vgg.pth 文件。
使用 torch.onnx.export 将模型导出为 ONNX 格式的文件 vgg.onnx。
总结:
定义 VGG 网络:构建了一个包含多个卷积层、池化层和全连接层的网络。
前向传播:对输入数据进行前向传播,得到输出结果。
保存模型:将模型保存为 .pth 文件。
导出 ONNX:将模型导出为 ONNX 格式,便于在其他平台或框架中使用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值