#搭建网络结构
import torch
import torch.nn as nn #卷积
import torch.nn.functional as F #softmax
#网络提供 init 和 forward 两个参数
class VGGbase(nn.Module):
def __init__(self):
super(VGGbase, self).__init__() #采用super进行网络的初始化 意思是自己搭建的网络Net会继承nn.Module:
self.conv1 = nn.Sequential( #第一个卷积采用序列
nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1),#输入是3轨道 输出是64轨道 卷积核尺寸3 步长1 图像填充+1 可以让输出输出的图保持一致
nn.BatchNorm2d(64), #传入输出的参数的数量 转化为
nn.ReLU() #完成非线性的变化
)
self.max_pooling1 = nn.MaxPool2d(kernel_size=2,stride=2) #池化层 降低采样防止过拟合
# pooling 之后图像下降了2倍 但channel要翻倍
self.conv2_1 = nn.Sequential( # 第二个卷积 多用几个
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), # 输入是3轨道 输出是64轨道 卷积核尺寸3 步长1 图像填充+1 可以让输出输出的图保持一致
nn.BatchNorm2d(128), # 传入输出的参数的数量 转化为
nn.ReLU() # 完成非线性的变化
)
self.conv2_2 = nn.Sequential( # 第二个卷积 多用几个
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), # 输入是3轨道 输出是64轨道 卷积核尺寸3 步长1 图像填充+1 可以让输出输出的图保持一致
nn.BatchNorm2d(128), # 传入输出的参数的数量 转化为
nn.ReLU() # 完成非线性的变化
)
self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2) # 池化层 降低采样防止过拟合
#7*7
self.conv3_1 = nn.Sequential( # 第二个卷积 多用几个
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), # 输入是3轨道 输出是64轨道 卷积核尺寸3 步长1 图像填充+1 可以让输出输出的图保持一致
nn.BatchNorm2d(256), # 传入输出的参数的数量 转化为
nn.ReLU() # 完成非线性的变化
)
self.conv3_2 = nn.Sequential( # 第二个卷积 多用几个
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), # 输入是3轨道 输出是64轨道 卷积核尺寸3 步长1 图像填充+1 可以让输出输出的图保持一致
nn.BatchNorm2d(256), # 传入输出的参数的数量 转化为
nn.ReLU() # 完成非线性的变化
)
self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) # 池化层 降低采样防止过拟合 padding=1 那么输出就不是3*3而是4*4了
#4*4
self.conv3_1 = nn.Sequential( # 第二个卷积 多用几个
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), # 输入是3轨道 输出是64轨道 卷积核尺寸3 步长1 图像填充+1 可以让输出输出的图保持一致
nn.BatchNorm2d(512), # 传入输出的参数的数量 转化为
nn.ReLU() # 完成非线性的变化
)
self.conv3_2 = nn.Sequential( # 第二个卷积 多用几个
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), # 输入是3轨道 输出是64轨道 卷积核尺寸3 步长1 图像填充+1 可以让输出输出的图保持一致
nn.BatchNorm2d(512), # 传入输出的参数的数量 转化为
nn.ReLU() # 完成非线性的变化
)
self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2) # 池化层 降低采样防止过拟合 padding=1 那么输出就不是3*3而是4*4了
#2*2
#fc层就是 全连接层 batchsize * 512 *2 *2 -> batchsize *(512 *2 *2)
self.fc =nn.Linear(512*4,10) #线性层 channel*大小(2*2=4) 输出10维对应10个类别
def forward(self,x): #在forward里面进行对输入的网络处理 串联起来构造分类网络
batchsize =x.size(0) #取出第0个维度
out =self.conv1(x) #取第一卷积
out =self.max_pooling1(out) #进行第一个pooling 注意这里上一层输出就是下一层的输入
out = self.conv2_1(out) # 取第一卷积
out = self.conv2_2(out)
out = self.max_pooling2(out)
out = self.conv3_1(out) # 取第一卷积
out = self.conv3_2(out)
out = self.max_pooling3(out)
out = self.conv4_1(out) # 取第一卷积
out = self.conv4_2(out)
out = self.max_pooling4(out)
out = out.view(batchsize,-1) #通过view函数进行展平 batchsize*n -1表示自动补齐
# batchsize * c* h* w -->batchsize *n
out = self.fc(out)
out = F.log_softmax(out ,dim =1)
return out
def VGGNet():
return VGGbase()
Pythorch VGG网络搭建 句句注释
最新推荐文章于 2024-05-15 21:54:47 发布