Pytorch -> 自定义数据集 + 迁移学习 + AE

自定义数据集必须定义的三个函数

class Pokemon(Dataset):
    # 划分训练与测试集
    def __init__(self, training=True):
        if training:
            self.samples = list(range(1,1001))
        else:
            self.samples = list(range(1001,1501))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

其余见代码文件

主要思路:

  1. 定义数据集(引入参数、划分数据集、索引、生成图片与标签)
  2. 生成CSV文件(生成images列表、写入CSV、读取CSV)
  3. 预处理(设置预处理参数和visdom)
  4. 迁移学习

自编码器

 让p尽可能大

class AE(nn.Module):

    def __init__(self):
        super(AE, self).__init__()


        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )


    def forward(self, x):
        """

        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)

        return x, None
class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()


        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

        self.criteon = nn.MSELoss()

    def forward(self, x):
        """

        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        # [b, 20], including mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        # μ和σ
        mu, sigma = h_.chunk(2, dim=1)
        # reparametrize trick, epison~N(0, 1)
        h = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)
        # 计算KL函数,1e-8是为了限幅
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld
import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasets

from    ae import AE
from    vae import VAE

import  visdom

def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)


    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)


    x, _ = iter(mnist_train).next()
    print('x:', x.shape)

    device = torch.device('cuda')
    # model = AE().to(device)
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):


        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = - loss - 1.0 * kld
                loss = - elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

        x, _ = iter(mnist_test).next()
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

  【深度学习】 自编码器(AutoEncoder) - 知乎

VAE(Variational Autoencoder)简单推导及理解_cjh_jinduoxia的博客-CSDN博客_vae推导

个人理解的VAE:原图采样成z,将z(符合标准正态分布)的结果映射到x(符合高斯分布μσ),目标是是logP(x)尽可能大

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
UnsatisfiableError: The following specifications were found to be incompatible with the existing python installation in your environment: Specifications: - torchaudio -> python[version='2.7.*|3.5.*|3.6.*|>=2.7,<2.8.0a0|>=3.5,<3.6.0a0|3.4.*'] Your python: python=3.10 If python is on the left-most side of the chain, that's the version you've asked for. When python appears to the right, that indicates that the thing on the left is somehow not available for the python version you are constrained to. Note that conda will not change your python version to a different minor version unless you explicitly specify that. The following specifications were found to be incompatible with each other: Output in format: Requested package -> Available versions Package pytorch-cuda conflicts for: pytorch -> pytorch-cuda[version='>=11.6,<11.7|>=11.7,<11.8|>=11.8,<11.9'] torchvision -> pytorch==2.0.1 -> pytorch-cuda[version='>=11.6,<11.7|>=11.7,<11.8|>=11.8,<11.9'] torchvision -> pytorch-cuda[version='11.6.*|11.7.*|11.8.*'] torchaudio -> pytorch-cuda[version='11.6.*|11.7.*|11.8.*'] torchaudio -> pytorch==2.0.1 -> pytorch-cuda[version='>=11.6,<11.7|>=11.7,<11.8|>=11.8,<11.9'] Package requests conflicts for: python=3.10 -> pip -> requests torchvision -> requests Package pytorch conflicts for: torchaudio -> pytorch[version='1.10.0|1.10.1|1.10.2|1.11.0|1.12.0|1.12.1|1.13.0|1.13.1|2.0.0|2.0.1|1.9.1|1.9.0|1.8.1|1.8.0|1.7.1|1.7.0|1.6.0'] torchvision -> pytorch[version='1.10.0|1.10.1|1.10.2|1.11.0|1.12.0|1.12.1|1.13.0|1.13.1|2.0.0|2.0.1|1.9.1|1.9.0|1.8.1|1.8.0|1.7.1|1.7.0|1.6.0|1.5.1'] Package msvc_runtime conflicts for: torchvision -> python[version='>=3.5,<3.6.0a0'] -> msvc_runtime pytorch -> python[version='>=3.5,<3.6.0a0'] -> msvc_runtime Package setuptools conflicts for: python=3.10 -> pip -> setuptools pytorch -> jinja2 -> setuptools torchvision -> setuptools什么意思
06-09

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值