小白第一篇源码阅读笔记,理解一个十分简单的VGG网络源码,有助于新手入门pytorch,若有错误请多多指教。
先上源码链接
本文除了对源码的顺序和注释进行了一点删改外,其他与源码一致
首先看入口,也就是当我要去创建一个VGG模型时,我会调用的那个函数:
def vgg11(pretrained=False, progress=True, **kwargs):
return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
def vgg11_bn(pretrained=False, progress=True, **kwargs):
return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
def vgg13(pretrained=False, progress=True, **kwargs):
return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
def vgg13_bn(pretrained=False, progress=True, **kwargs):
return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
def vgg16(pretrained=False, progress=True, **kwargs):
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
def vgg16_bn(pretrained=False, progress=True, **kwargs):
return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
def vgg19(pretrained=False, progress=True, **kwargs):
return _vgg('vgg19'<