手写体数字的自编码器(Autoencoders for handwritten digits)

链接:https://pan.baidu.com/s/1mvvjY-ErtgjAkcUkK6p_Fw 
提取码:xsa4

最近要对编码器解码器有一些学习,就从最简单的手写体数字识别切入,完成的效果还行吧,自己写完留个记录省着找不着

数据集

数据集就是直接从官方下载,batch_size我根据电脑设置为64

import torchvision as tv
from torch.utils.data import DataLoader

def get_data_loaders(batch_size=64, data_dir="dataset"):
    train_data = tv.datasets.MNIST(root=data_dir, train=True, transform=tv.transforms.ToTensor(), download=True)
    test_data = tv.datasets.MNIST(root=data_dir, train=False, transform=tv.transforms.ToTensor(), download=True)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

模型

模型的话(讲真的我不知道她到底能怎么样,编码结构就是卷积,池化,卷积,池化,全连接,解码器结构就是,反全连接,两个反卷积)模型就是这样的,PS:其实在这里我有个疑问,如果将编码器解码器拆开,那编码器是不是就可以直接进行分类,对于手写体数字识别来说,最后只要有十个特征向量就可以了,我是这么觉得的,过后我把中间全连接的十个数据打印出来看一看怎么事,模型就是这么写的,其实很简单

self.encode = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1600, 400),
            nn.ReLU(True),
            nn.Linear(400 , 10)
        )

        self.decode = nn.Sequential(
            nn.Linear(10,400),
            nn.ReLU(True),
            nn.Linear(400,1600),
            nn.Unflatten(1,(64,5,5)),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=3, padding=1, output_padding=1),  # 上采样
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 上采样到原始图像大小
            nn.Sigmoid()
        )

训练

优化器和损失函数

optimizer = torch.optim.Adam(model.parameters(),lr = 0.001)
loss_function = torch.nn.MSELoss().to(device)

 训练过程就是一个循环

    for data in train_loader:
        imgs,_ = data
        imgs = imgs.to(device)
        out_img = model(imgs)
        loss = loss_function(out_img,imgs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

 我让他跑50轮,咋说呢50轮的时候他也没收敛,但是看着效果还行

评估(测试)

评估的话是我找的手写体的图片

 找了3个都还可以,就是对数据预处理这块有些问题,需要将咱们自己的图片转化成模型需要的格式

transform = tv.transforms.Compose([
    tv.transforms.Grayscale(),
    tv.transforms.Resize((28,28)),
    tv.transforms.ToTensor()
])


img = torch.reshape(img,(1,1,28,28))

 最后把他送到咱们的模型里面就ok了,不过你是用matlab的那个库直接打印出来,还是用tensorboard都可以,我用的是后者。

emmm百度网盘链接放这里了,写的挺烂,对付看吧兄弟们(实现的效果还可以嘿嘿)

 要是哪块有疑问了,请留言,哪里有问题也请直接指出

 链接:https://pan.baidu.com/s/1mvvjY-ErtgjAkcUkK6p_Fw 
提取码:xsa4

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值