第一句代码:
目的是在当前版本使用新版本功能,如2.7中使用3.6的功能
-
torch.utils.model_zoo.load_url(url, model_dir=None)
url:要下载对象的URL
model_dir:下载到本地文件地址 -
class MultiInputImages(models.ResNet):
属于类的继承,如果要编写的类是另一个现成类的特殊版本,可以采取继承父类的方法。 -
def __init__(self, block, layers, num_classes=1000, num_input_images=1): super(ResNetMultiImageInput, self).__init__(block, layers)
__init()__接受创建Resnet所需的信息,super()将父类与子类联结起来 -
resnet中常见的make_layer(),
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)