pytorch学习之路——09

autoencoder:分为编码和解码两个阶段,编码阶段相当于提取特征,解码可以看成编码的"逆变换",最后的输出维度与输入维度一致,损失函数选择最小均方误差。注意在autoencoder中没有使用标签信息。

linear_autoencoder:

import torch
from torch import nn,optim
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.utils.data.dataloader import DataLoader


def get_dataloader(batch_size):
    transform = transforms.Compose([transforms.ToTensor()])
    trainset = datasets.MNIST('MNIST_data/',train=True,download=True,transform=transform)
    testset = datasets.MNIST('MNIST_data/',train=False,download=True,transform=transform)

    trainLoader = DataLoader(trainset,batch_size=batch_size)
    testLoader = DataLoader(testset,batch_size=batch_size)

    return trainLoader,testLoader


class AutoEncoder(nn.Module):
    def __init__(self,fea_dim):
        super().__init__()
        # encoder
        self.fc1 = nn.Linear(784,fea_dim)
        # decoder
        self.fc2 = nn.Linear(fea_dim,784)

    def forward(self,x):
        encoder_x = F.relu(self.fc1(x))
        decoder_x = self.fc2(encoder_x)

        return torch.sigmoid(decoder_x)


def train(trainLoader,epoch,fea_dim):
    model = AutoEncoder(fea_dim)
    costFunc = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(),lr=0.001)

    for e in range(epoch):
        trainLoss = 0
        for imgs,_ in trainLoader:
            imgs = imgs.view(imgs.shape[0],-1)
            optimizer.zero_grad()
            output = model.forward(imgs)
            loss = costFunc(output,imgs)
            loss.backward()
            optimizer.step()
            trainLoss += loss.item()*imgs.shape[0]

        trainLoss /= len(trainLoader)
        print("Epoch: {}/{},".format(e,epoch),
              "Train Loss: {:.3f}".format(trainLoss))


if __name__ == "__main__":
    epoch = 10
    fea_dim = 32
    batch_size = 20
    valid_size = 0.1
    trainLoader,testLoader = get_dataloader(batch_size)
    train(trainLoader,epoch,fea_dim)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值