#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# @Time : 2020/2/25 下午12:43
# @Author : MJ
import torch as t
import torch.nn as nn
import math
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.nn import functional as F
import torch.optim as optim
import torchvision as tv
# from torch.autograd import Variable
#define model
class VGG(nn.Module):
def __init__(self, features, num_classes=10, in_features=512,
             hidden_features=4096, dropout=0.5):
    """Build a VGG network from a feature extractor plus a 3-layer classifier.

    Args:
        features: module holding the convolution/pooling stack only (no
            classifier). ``forward`` flattens its output per sample before
            passing it to the classifier.
        num_classes: number of outputs of the final linear layer (fc8).
        in_features: flattened per-sample size of ``features``' output. The
            default 512 reproduces the original hard-coded value; presumably
            it matches a (512, 1, 1) feature map for small inputs — confirm
            against the feature config (e.g. use 512 * 7 * 7 for a 7x7 map).
        hidden_features: width of the two hidden layers (fc6 and fc7).
        dropout: drop probability of both Dropout layers; 0.5 equals the
            ``nn.Dropout`` default that the original code relied on.
    """
    super(VGG, self).__init__()
    # Network body: convolution and pooling layers only (no classifier).
    self.features = features
    # NOTE: the misspelled attribute name "classifer" is kept deliberately:
    # forward() reads it, and renaming it would change the state_dict keys
    # of any saved checkpoints.
    self.classifer = nn.Sequential(
        # fc6
        nn.Linear(in_features, hidden_features),
        nn.ReLU(),
        nn.Dropout(p=dropout),
        # fc7
        nn.Linear(hidden_features, hidden_features),
        nn.ReLU(),
        nn.Dropout(p=dropout),
        # fc8
        nn.Linear(hidden_features, num_classes),
    )
    # Initialise the weights of all sub-modules.
    self._initialize_weights()
def forward(self, x):
    """Run the feature extractor, flatten each sample, and classify.

    Args:
        x: input batch; dimension 0 is treated as the batch dimension.

    Returns:
        The output of ``self.classifer`` applied to the flattened features,
        one row per sample.
    """
    x = self.features(x)
    # flatten(1) keeps the batch dim and merges all remaining dims. Unlike
    # x.view(x.size(0), -1) it also handles non-contiguous feature maps
    # (e.g. channels_last), copying only when needed, and an empty batch.
    x = t.flatten(x, 1)
    return self.classifer(x)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m,nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0,math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m,nn.BatchNorm2d):
pytorch--实现vgg
最新推荐文章于 2022-11-08 11:06:11 发布