import numpy as np
import torch
import torchvision
from torchvision import models
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from src import fcn_resnet50, resnet50
from model_fcn8s import VGG, fcn_vgg16
# Build the FCN VGG backbone and, optionally, initialise it from the
# torchvision VGG16 checkpoint file, printing diagnostics that show which
# parameters were actually transferred.
pretrain_backbone = True
# Presumably one (num_conv_layers, out_channels) pair per VGG stage; these values
# match the VGG16 layout (2+2+3+3+3 convs) -- TODO confirm against VGG in model_fcn8s.
struct = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]
# NOTE(review): 21 classes is presumably PASCAL VOC (20 + background) -- confirm.
backbone = VGG(num_classes=21, struct=struct)

if pretrain_backbone:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from a
    # trusted source (newer PyTorch versions accept weights_only=True).
    weights_dict = torch.load("vgg16-397923af.pth", map_location="cpu")

    print("Checkpoint parameters:")
    for name, param in weights_dict.items():
        print(f"{name}: {tuple(param.shape)}")

    # strict=False tolerates missing/extra keys but still raises on a shape
    # mismatch for a key present in both dicts (e.g. a classifier sized for a
    # different number of classes), so drop those entries up front.
    model_state = backbone.state_dict()
    skipped_keys = {
        key
        for key, value in weights_dict.items()
        if key in model_state and value.shape != model_state[key].shape
    }
    compatible_weights = {
        key: value for key, value in weights_dict.items() if key not in skipped_keys
    }
    load_result = backbone.load_state_dict(compatible_weights, strict=False)

    # With strict=False a naming mismatch between the checkpoint and VGG fails
    # silently, so surface everything that did not line up.
    print(f"Skipped (shape mismatch): {sorted(skipped_keys)}")
    print(f"Missing from checkpoint (left at init): {load_result.missing_keys}")
    print(f"Unused checkpoint entries: {load_result.unexpected_keys}")

    print("Loaded pretrained model parameters:")
    for name, param in backbone.state_dict().items():
        print(f"{name}: {tuple(param.shape)}")
# Residue from the scraped blog page (post date, read count, read-count icon); not code:
# 04-06
# 1万+
# ![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)