pytorch-AutoEncoders实战

1. AutoEncoders回顾

如下图:AutoEncoders实际上就是重建自己的过程
在这里插入图片描述

2. 实现网络结构

创建类继承自nn.Model,并实现init和forward函数,init中实现encoder、decoder
直接上代码,其中encoder维度变化784->256->64->20,decoder是相反过程

import  torch
from    torch import nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()
        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """

        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)

        return x, None

3. 实现main函数

  • 加载minist数据集
  • 初始化网络
  • 初始化loss函数
  • 初始化优化器
  • 循环迭代数据输入模型、计算loss、优化器优化参数
    完整代码
import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasets

from    ae import AE
#from    vae import VAE

import  visdom

def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)


    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = iter(mnist_train).next()
    print('x:', x.shape)

    device = torch.device('cuda')
    model = AE().to(device)
    #model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, _ = model(x)
            loss = criteon(x_hat, x)
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

        x, _ = iter(mnist_test).next()
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

if __name__ == '__main__':
    main()
根据引用\[1\]和引用\[2\]的内容,你可以使用以下命令在虚拟环境中安装PyTorch和相关库: ``` mamba install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia ``` 这个命令会安装PyTorch、torchvision和torchaudio,并指定使用CUDA 11.7版本。同时,它会从pytorch和nvidia的频道中获取软件包。 然而,根据引用\[3\]的内容,如果你在指定的镜像源中找不到指定版本的PyTorch,可能会导致安装的是CPU版本而不是GPU版本。为了解决这个问题,你可以尝试使用其他镜像源或者手动指定安装GPU版本的PyTorch。 综上所述,你可以尝试使用以下命令来安装PyTorch和相关库,并指定使用CUDA 11.7版本: ``` mamba install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia ``` 希望这能帮到你! #### 引用[.reference_title] - *1* [三分钟搞懂最简单的Pytorch安装流程](https://blog.csdn.net/weixin_44261300/article/details/129643480)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [Pytorch与NVIDA驱动控制安装](https://blog.csdn.net/m0_48176714/article/details/129311194)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [解决使用conda下载pytorch-gpu版本困难的问题](https://blog.csdn.net/qq_41963301/article/details/131070422)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值