pytorch中保存网络和提取网络

保存网络:

# 保存全部网络
torch.save(网络, 网络名)
# 只保存网络参数
torch.save(网络.state_dict(), 网络名)
  • 保存整个网络,不需要再搭建结构;只保存网络参数需要在搭建之前一样的网络结构,再将参数放进去。就好比前者是去饭店买来一碗色香俱全的酸菜鱼,后者是老板加什么调料,多少调料,煮多久都告诉你,你回家自己做,做完和直接买来的一样了。
  • 据说只保存网络参数会快一点哦。

提取网络:

# 对应的提取整个网络
torch.load(网络名)
# 对用提取网络参数
网络.load_state_dict(torch.load(网络名))

以之前线性回归代码为例,用保存的网络,比较用两种方法提取的网络

import torch
import matplotlib.pyplot as plt


def save_all_net(net, net_name):
    """保存整个网络"""
    torch.save(net, net_name)


def save_net_parameters(net, net_name):
    """只保存网络中的参数"""
    torch.save(net.state_dict(), net_name)


def restore_net(net_name):
    """提取整个模型"""
    net = torch.load(net_name)
    return net


def restore_parameters(network, net_name):
    """提取网络中的参数"""
    network.load_state_dict(torch.load(net_name))

## 数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=-1)
y = x.pow(2)


def orignal_net():
    """原始网络"""
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
    loss_funcation = torch.nn.MSELoss()
    ## 训练
    for epoch in range(100):
        pridect = net(x)
        loss = loss_funcation(pridect, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), pridect.data.numpy(), 'r-', lw=5)
    # 保存模型
    save_all_net(net, "net.pkl")  # 保存整个模型
    save_net_parameters(net, "net_params.pkl")   # 只保存模型参数


def read_all_net():
    """恢复提取整个的网络"""
    net2 = restore_net("net.pkl")
    pridect = net2(x)

    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), pridect.data.numpy(), 'r-', lw=5)


def read_parameters():
    """恢复只提取参数的网络"""
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    restore_parameters(net3, "net_params.pkl")
    pridect = net3(x)

    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), pridect.data.numpy(), 'r-', lw=5)
    plt.show()


# 主函数
orignal_net()
read_all_net()
read_parameters()
可视化图:证明一样

在这里插入图片描述

GAN(Generative Adversarial Network)是一种生成式深度学习模型,它由两个神经网络组成:一个生成器网络和一个判别器网络。生成器网络可以生成逼真的图像、文本或音频等,而判别器网络则用于区分生成器生成的图像与真实图像的不同之处。两个网络不断地相互对抗、优化,直到生成的图像与真实图像无法区分。 下面是一个用 PyTorch 实现的简单的 GAN 模型: ```python import torch import torch.nn as nn import torch.optim as optim import torchvision.datasets as dset import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.utils import save_image # 定义生成器网络 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.fc = nn.Sequential( nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 784), nn.Tanh() ) def forward(self, x): x = self.fc(x) return x # 定义判别器网络 class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.fc = nn.Sequential( nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 1), nn.Sigmoid() ) def forward(self, x): x = self.fc(x) return x # 加载 MNIST 数据集 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,)) ]) train_dataset = dset.MNIST(root='./data', train=True, transform=transform, download=True) train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) # 初始化生成器和判别器 G = Generator() D = Discriminator() # 定义优化器和损失函数 G_optimizer = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999)) D_optimizer = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999)) criterion = nn.BCELoss() # 训练 GAN 模型 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') G.to(device) D.to(device) for epoch in range(100): for i, (images, _) in enumerate(train_loader): batch_size = images.size(0) images = images.view(batch_size, -1).to(device) # 训练判别器 real_labels = torch.ones(batch_size, 1).to(device) fake_labels = torch.zeros(batch_size, 1).to(device) noise = torch.randn(batch_size, 100).to(device) fake_images = G(noise) D_real_outputs = D(images) D_fake_outputs = D(fake_images.detach()) D_real_loss = criterion(D_real_outputs, real_labels) D_fake_loss = criterion(D_fake_outputs, fake_labels) D_loss = D_real_loss + D_fake_loss D_optimizer.zero_grad() D_loss.backward() D_optimizer.step() # 训练生成器 noise = torch.randn(batch_size, 100).to(device) fake_images = G(noise) D_fake_outputs = D(fake_images) G_loss = criterion(D_fake_outputs, real_labels) G_optimizer.zero_grad() G_loss.backward() G_optimizer.step() # 输出损失值 if i % 100 == 0: print(f'Epoch [{epoch+1}/{100}] Batch [{i+1}/{len(train_loader)}] D_loss: {D_loss.item():.4f}, G_loss: {G_loss.item():.4f}') # 保存生成的图像 with torch.no_grad(): noise = torch.randn(64, 100).to(device) fake_images = G(noise).view(64, 1, 28, 28) save_image(fake_images, f'./gan_images/{epoch+1}.png') ``` 在这个例子,我们使用了 PyTorch 内置的 MNIST 数据集,并定义了一个含有三个全连接层的生成器网络和一个含有两个全连接层的判别器网络。我们采用了 Adam 优化器和二元交叉熵损失函数。在训练过程,我们不断地交替训练生成器和判别器,并且每完成一个 epoch 就保存一批生成的图像。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值