以 https://blog.csdn.net/hxxjxw/article/details/105667269 为例
源代码,运行后会生成fcnet.pth保存网络参数
# Training script: trains a 3-layer fully-connected MNIST classifier and
# saves its parameters (state_dict only, not the whole module) to 'fcnet.pth'.
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch import optim

from utils import plot_image, plot_curve, one_hot

batch_size = 512

# DataLoaders normalize with the standard MNIST mean/std (0.1307, 0.3081).
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'dataset/', train=True, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'dataset/', train=False, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])),
    batch_size=batch_size, shuffle=True)

# x is the image batch, y is the label batch
# x, y = next(iter(train_loader))
# print(x.shape, y.shape, x.min(), x.max())
# plot_image(x, y, 'image sample')


class Net(nn.Module):
    """Fully-connected network for MNIST: 784 -> 256 -> 64 -> 10."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Raw scores; the training loop compares them against a one-hot target.
        x = self.fc3(x)
        return x


net = Net()

# k, v are each parameter's name and tensor (the weights and biases).
for k, v in net.named_parameters():
    print(k, v)

optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # Flatten (B, 1, 28, 28) -> (B, 784) for the linear layers.
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        y_onehot = one_hot(y)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

# Persist only the learned parameters; loading requires rebuilding Net first.
torch.save(net.state_dict(), 'fcnet.pth')
plot_curve(train_loss)

# Evaluation: switch to eval mode and disable autograd bookkeeping so the
# test pass does not build gradient graphs.
net.eval()
total_correct = 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test accuracy:', acc)
此时还是使用原来的代码加载pth,砍掉了训练环节
# Loading script: rebuilds the same Net architecture, restores the trained
# parameters from 'fcnet.pth', and evaluates accuracy on the MNIST test set.
# (The training loop is intentionally absent; no train DataLoader is needed.)
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch import optim

from utils import plot_image, plot_curve, one_hot

batch_size = 512

# Only the test split is iterated here; download=True fetches the raw MNIST
# files if they are not already under 'dataset/'.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        'dataset/', train=False, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])),
    batch_size=batch_size, shuffle=True)

# x is the image batch, y is the label batch
# x, y = next(iter(test_loader))
# print(x.shape, y.shape, x.min(), x.max())
# plot_image(x, y, 'image sample')


class Net(nn.Module):
    """Fully-connected network for MNIST: 784 -> 256 -> 64 -> 10.

    Must match the architecture used at training time, otherwise
    load_state_dict() will fail on mismatched parameter shapes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

# Parameters before loading: freshly initialized weights and biases.
for k, v in net.named_parameters():
    print(k, v)

# Restore the trained parameters saved by the training script.
net.load_state_dict(torch.load('fcnet.pth'))

# Parameters after loading: now the pretrained values.
for k, v in net.named_parameters():
    print(k, v)

# Evaluation: switch to eval mode and disable autograd bookkeeping so the
# test pass does not build gradient graphs.
net.eval()
total_correct = 0
with torch.no_grad():
    for x, y in test_loader:
        # Flatten (B, 1, 28, 28) -> (B, 784) for the linear layers.
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test accuracy:', acc)
可以看出其识别准确率直接就到了0.88了,这就是直接使用了预训练的模型的参数
关于加载全部模型还是只加载参数,可看 https://blog.csdn.net/u014264373/article/details/85332181
以全连接网络MNIST识别为例的模型参数保存与加载
最新推荐文章于 2023-05-15 11:28:49 发布