对比自己的模型结构和预训练加载的模型结构是否一致

import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from torchvision.models import vgg16
from model_fcn8s import VGG

# 期望的FCN模型实例
struct = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]
expected_model = VGG(num_classes=21, struct=struct)

# 加载预训练的VGG16模型
pretrained_vgg16 = vgg16(pretrained=True)
pretrained_vgg16 = pretrained_vgg16.features

# 对比期望的模型结构和加载的模型结构
for expected_param, loaded_param in zip(expected_model.parameters(), pretrained_vgg16.parameters()):
    assert expected_param.shape == loaded_param.shape, "Parameter shape mismatch!"

print("Model structure matches the expected structure.")
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值