import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.autograd import variable
from scipy import misc
from torchvision import transforms, datasets
from torchvision.utils import save_image
import os
batch_size = 64

# MNIST splits as [0, 1] float tensors.
# BUG FIX: the test split must be requested with train=False; the original
# code downloaded the training split twice.
train_data = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor(), download=True)

# Shuffle only the training stream; evaluation order should be deterministic.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

# Directory for reconstruction snapshots; exist_ok makes this idempotent.
os.makedirs('./output_image', exist_ok=True)
def to_img(x):
    """Map tanh-range outputs in [-1, 1] to [0, 1] image tensors.

    Args:
        x: batch-first tensor, flattened or already image-shaped
           (assumes 28*28 pixels per sample — TODO confirm with callers).

    Returns:
        A (N, 1, 28, 28) tensor with values clamped to [0, 1].
    """
    rescaled = (x + 1) * 0.5
    return rescaled.clamp(0, 1).view(rescaled.size(0), 1, 28, 28)
def Load_image(path='/Users/changxingya/Downloads/UCSD_Anomaly_Dataset.v1p2/UCSDped1/Train/Train001/001.tif'):
    """Load one image frame from disk as a torch tensor.

    Args:
        path: image file to read. Defaults to the originally hard-coded
            UCSD Ped1 frame so existing no-argument callers are unaffected.

    Returns:
        torch.Tensor sharing memory with the decoded numpy array.
    """
    # NOTE(review): scipy.misc.imread was deprecated in SciPy 1.0 and removed
    # in 1.2 — this call fails on modern SciPy; migrate to imageio or PIL.
    tif_data = misc.imread(path)
    tif_data_tensor = torch.from_numpy(tif_data)
    print(tif_data_tensor.size())
    return tif_data_tensor
class TemporalDetection(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 inputs.

    The encoder compresses (N, 1, 28, 28) down to (N, 32, 4, 4); the
    decoder mirrors it back to (N, 1, 28, 28) with a tanh output, so
    reconstructions live in [-1, 1].
    """

    def __init__(self):
        super(TemporalDetection, self).__init__()
        # Encoder: (N,1,28,28) -> (N,16,9,9) -> (N,32,4,4)
        encoder_layers = [
            nn.Conv2d(1, 16, 3, stride=3),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2),
            nn.ReLU(True),
        ]
        # Decoder: (N,32,4,4) -> (N,16,14,14) -> (N,1,28,28)
        decoder_layers = [
            nn.ConvTranspose2d(32, 16, 5, stride=3),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 1, 2, stride=2),
            nn.Tanh(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode then decode; returns the reconstruction of x."""
        return self.decoder(self.encoder(x))
# Instantiate the autoencoder and echo its architecture.
model = TemporalDetection()
print(model)

# Pixel-wise mean squared error scores the reconstruction; Adam with the
# conventional 1e-3 learning rate drives the updates.
loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def train(epoch):
    """Run one epoch of reconstruction training over train_loader.

    Args:
        epoch: 1-based epoch index, used for logging and snapshot filenames.

    Side effects: updates the module-level ``model`` in place via
    ``optimizer`` and writes a reconstruction image every 10 batches
    to ./output_image/.
    """
    for i, (data, target) in enumerate(train_loader):
        output = model(data)
        # Autoencoder objective: the output should reproduce the input.
        loss = loss_function(output, data)
        # BUG FIX: print the scalar via .item() instead of the tensor repr.
        print("Temporal Detection Loss = {}".format(loss.item()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            # Detach: the snapshot must not keep the autograd graph alive.
            img = to_img(output.detach())
            # BUG FIX: include the epoch in the filename; previously every
            # epoch overwrote the prior epoch's images and the ``epoch``
            # argument was never used.
            save_image(img, './output_image/image_{}_{}.png'.format(epoch, i))
# Train for nine epochs (1 through 9), then report completion.
for epoch_idx in range(1, 10):
    train(epoch_idx)
print("Done")
# PyTorch MNIST autoencoder example (residual blog-page text removed: the
# original lines were not valid Python and broke the script's syntax).