# -*- coding: utf-8 -*- # 作者:小土堆 # 公众号:土堆碎念 import torch from model_save import * # 方式1-》保存方式1,加载模型 import torchvision from torch import nn model = torch.load("vgg16_method1.pth") # print(model) # 方式2,加载模型 vgg16 = torchvision.models.vgg16(pretrained=False) vgg16.load_state_dict(torch.load("vgg16_method2.pth")) # model = torch.load("vgg16_method2.pth") # print(vgg16) # 陷阱1 # class Tudui(nn.Module): # def __init__(self): # super(Tudui, self).__init__() # self.conv1 = nn.Conv2d(3, 64, kernel_size=3) # # def forward(self, x): # x = self.conv1(x) # return x model = torch.load('tudui_method1.pth') print(model)
土堆老师搭建神经网络代码
最新推荐文章于 2024-10-08 20:27:53 发布