PyTorch 实现 MNIST 卷积自编码器 (autoencoder)

import torch

import torch.nn.functional as F

import torch.nn as nn

import torch.optim as optim

from torch.autograd import variable

from scipy import misc

from torchvision import transforms, datasets

from torchvision.utils import save_image

import os

batch_size = 64

# Scale pixels to [-1, 1]: ToTensor() gives [0, 1], Normalize(0.5, 0.5) maps
# that to [-1, 1]. This matches the range of the decoder's final nn.Tanh()
# and the [-1, 1] -> [0, 1] inverse applied by to_img(). Without it, saved
# reconstructions end up squashed into [0.5, 1] (washed-out images).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_data = datasets.MNIST(root='./data/', train=True, transform=mnist_transform, download=True)

# train=False selects the held-out MNIST split; passing train=True here would
# silently make the "test" set a copy of the training set.
test_data = datasets.MNIST(root='./data/', train=False, transform=mnist_transform, download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

# Directory for the reconstruction snapshots written during training.
# exist_ok avoids the check-then-create race of exists() + mkdir().
os.makedirs('./output_image', exist_ok=True)

def to_img(x):
    """Convert a batch of network outputs into images ready for saving.

    Values are shifted from the [-1, 1] range to [0, 1] via ``(x + 1) / 2``,
    clipped into [0, 1], and reshaped to ``(N, 1, 28, 28)``. Each sample must
    therefore hold exactly 28 * 28 elements.
    """
    rescaled = (x + 1) * 0.5
    clipped = torch.clamp(rescaled, min=0, max=1)
    batch = clipped.size(0)
    return clipped.view(batch, 1, 28, 28)

def Load_image(path='/Users/changxingya/Downloads/UCSD_Anomaly_Dataset.v1p2/UCSDped1/Train/Train001/001.tif'):
    """Read one image file and return it as a torch tensor.

    Args:
        path: Image file to read. The default is a UCSD Ped1 Train001 frame
            on the original author's machine; pass your own path elsewhere.

    Returns:
        ``torch.from_numpy`` of the array produced by ``scipy.misc.imread``
        (it shares memory with that array). Its size is printed first.

    Raises:
        AttributeError: if the installed SciPy no longer provides
            ``scipy.misc.imread`` (deprecated in SciPy 1.0 and since removed).
    """
    imread = getattr(misc, 'imread', None)
    if imread is None:
        # Fail with an actionable message instead of a bare attribute lookup error.
        raise AttributeError(
            "scipy.misc.imread is not available in this SciPy version "
            "(it was deprecated in 1.0 and removed); use an older SciPy "
            "or switch Load_image to another image reader")

    tif_data = imread(path)

    tif_data_tensor = torch.from_numpy(tif_data)

    print(tif_data_tensor.size())

    return tif_data_tensor

class TemporalDetection(nn.Module):
    """Small convolutional autoencoder.

    The encoder downsamples with two strided convolutions (each followed by
    ReLU). The decoder upsamples with two transposed convolutions and squashes
    the result into [-1, 1] with ``nn.Tanh``. For a 1x28x28 input, the spatial
    size goes 28 -> 9 -> 4 through the encoder and 4 -> 14 -> 28 through the
    decoder, so the reconstruction has the input's shape.
    """

    def __init__(self):
        super().__init__()
        # (N, 1, 28, 28) -> (N, 16, 9, 9) -> (N, 32, 4, 4) for 28x28 inputs.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
        )
        # (N, 32, 4, 4) -> (N, 16, 14, 14) -> (N, 1, 28, 28).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=5, stride=3),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=2, stride=2),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        latent = self.encoder(x)
        return self.decoder(latent)

# Build the autoencoder and show its layer structure.
model = TemporalDetection()
print(model)

# Reconstruction objective (mean squared error between output and input),
# minimised with Adam at a learning rate of 1e-3.
loss_function = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train(epoch):
    """Run one pass of autoencoder training over ``train_loader``.

    Each batch is reconstructed by ``model``. The MSE between reconstruction
    and input is minimised with ``optimizer``, and the reconstruction of every
    10th batch is saved under ``./output_image``.

    Args:
        epoch: Epoch number. It is used in the log lines and in the saved image
            names, so later epochs no longer overwrite earlier epochs' images.
    """
    model.train()
    for i, (data, _) in enumerate(train_loader):  # class labels are unused

        output = model(data)

        # Autoencoder: the input itself is the reconstruction target.
        loss = loss_function(output, data)

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        # .item() prints the scalar value rather than the tensor repr
        # (which includes device/grad_fn noise).
        print("Epoch {} batch {}: Temporal Detection Loss = {:.6f}".format(epoch, i, loss.item()))

        if i % 10 == 0:
            # detach() keeps autograd from recording to_img's ops on the output.
            img = to_img(output.detach())

            save_image(img, './output_image/image_{}_{}.png'.format(epoch, i))

# Train for epochs 1 through 9 (the range end point is exclusive) and print
# "Done" after each epoch finishes.
# NOTE(review): an earlier experiment loaded a single UCSD frame here with
# Load_image(); that path is not part of the MNIST training run.
for i in range(1, 9 + 1):
    train(i)
    print("Done")
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值