Python Study Notes: Autoencoders

Unsupervised Learning

Autoencoder
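An autoencoder learns a compressed representation without labels: the encoder squeezes each input through a low-dimensional bottleneck and the decoder is trained to reconstruct the original input. The script below trains one on MNIST with an MSE reconstruction loss and uses visdom to visualize the inputs next to their reconstructions.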

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from ae import AE
from torch import nn, optim
import visdom
def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    # test split: pass False for the train flag
    mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)
    x, _ = next(iter(mnist_train))
    print('x:', x.shape)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()
    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)
            x_hat = model(x)
            loss = criteon(x_hat, x)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())
        # visualize one test batch: originals vs. reconstructions
        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():
            x_hat = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))


if __name__=='__main__':
    main()
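The AE model imported above is defined in a separate file, ae.py: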
import torch
from torch import nn

class AE(nn.Module):

    def __init__(self):
        super(AE, self).__init__()

        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        '''
        :param x: [b, 1, 28, 28]
        :return: reconstruction, [b, 1, 28, 28]
        '''
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)

        return x

Variational Autoencoder (VAE)
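Compared with the plain autoencoder, the VAE encoder outputs a mean \mu and a standard deviation \sigma for each latent dimension instead of a single code. A latent sample is drawn as

z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1)

(the reparameterization trick, which makes the sampling step differentiable), and the training loss adds the closed-form KL divergence between \mathcal{N}(\mu, \sigma^2) and the standard normal prior, which is exactly what the kld term in the code below computes:

D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big) = \frac{1}{2}\sum_{i}\left(\mu_i^2 + \sigma_i^2 - \log\sigma_i^2 - 1\right)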

import torch
from torch import nn

class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()

        # [b, 784] => [b, 20]
        # mu: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            # no final activation: the 20 outputs are split into mu and sigma,
            # and mu must be able to take negative values
            nn.Linear(64, 20)
        )
        # [b, 10] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        '''
        :param x: [b, 1, 28, 28]
        :return: x_hat [b, 1, 28, 28], kld (scalar)
        '''
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder: [b, 20], holding mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparameterization trick: epsilon ~ N(0, 1)
        # turns sampling from a normal distribution into a differentiable op
        h = mu + sigma * torch.randn_like(sigma)
        # closed-form KL divergence against the standard normal prior,
        # averaged over all pixels in the batch
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz * 28 * 28)

        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        return x_hat, kld
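The notes stop at the VAE model and do not show its main(). Below is a minimal training-loop sketch following the same pattern as the AE script above; the module name vae.py and the 1.0 weight on the KL term are assumptions, not from the original notes.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from vae import VAE  # assumes the VAE class above is saved as vae.py

def main():
    mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        for x, _ in mnist_train:
            x = x.to(device)
            x_hat, kld = model(x)
            # total loss = reconstruction error + weighted KL regularizer
            # (the 1.0 weight is an assumed choice, not from the original notes)
            loss = criteon(x_hat, x) + 1.0 * kld

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

if __name__ == '__main__':
    main()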