import os

import numpy as np
import torch
import torch.optim as optim
from torch import nn


class AEEncoder(nn.Module):
    """MLP encoder: flattens the input to ``input_features`` values per sample
    and maps it through Linear+BatchNorm+PReLU stages down to ``output_features``.

    Args:
        input_features: flattened size of one input sample (e.g. 784 for 28x28 MNIST).
        output_features: dimensionality of the latent code.
        hidden_dims: widths of the hidden layers; defaults to [512, 256, 128].
    """

    def __init__(self, input_features, output_features, hidden_dims=None):
        super(AEEncoder, self).__init__()
        self.input_features = input_features
        if hidden_dims is None:
            hidden_dims = [512, 256, 128]
        layers = []
        for h_dim in hidden_dims:
            layers.append(
                nn.Sequential(
                    nn.Linear(in_features=input_features, out_features=h_dim),
                    nn.BatchNorm1d(num_features=h_dim),
                    nn.PReLU(),
                )
            )
            input_features = h_dim  # next layer consumes this layer's output
        layers.append(nn.Linear(input_features, output_features))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        # Flatten (N, C, H, W) -> (N, input_features) before the MLP.
        x = x.view(-1, self.input_features)
        return self.encoder(x)


class AEDecoder(nn.Module):
    """MLP decoder: mirrors the encoder, expanding the latent code back to
    ``output_features`` values per sample.

    Args:
        input_features: dimensionality of the latent code.
        output_features: flattened size of one reconstructed sample.
        hidden_dims: widths of the hidden layers; defaults to [128, 256, 512].
    """

    def __init__(self, input_features, output_features, hidden_dims=None):
        super(AEDecoder, self).__init__()
        if hidden_dims is None:
            hidden_dims = [128, 256, 512]
        layers = []
        for h_dim in hidden_dims:
            layers.append(
                nn.Sequential(
                    nn.Linear(in_features=input_features, out_features=h_dim),
                    nn.BatchNorm1d(num_features=h_dim),
                    nn.PReLU(),
                )
            )
            input_features = h_dim
        layers.append(nn.Linear(input_features, output_features))
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        # x: (N, latent_dim) -> (N, output_features)
        return self.decoder(x)


class AE(nn.Module):
    """Autoencoder for 28x28 single-channel images (784 pixels, 32-dim latent)."""

    def __init__(self):
        super(AE, self).__init__()
        self.encoder = AEEncoder(784, 32)
        self.decoder = AEDecoder(32, 784)

    def forward(self, x):
        # Encode to the 32-dim latent, decode, and reshape back to image form.
        z = self.encoder(x)
        z = self.decoder(z)
        return z.view(-1, 1, 28, 28)


if __name__ == '__main__':
    # Script-only third-party imports live inside the guard so the model
    # classes above can be imported without torchvision installed.
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms, utils

    _batch_size = 16
    _total_epoch = 5

    transform = transforms.Compose([
        transforms.ToTensor(),
        # Mild random crop/rescale as augmentation (keeps the 28x28 size).
        transforms.RandomResizedCrop(size=(28, 28), scale=(0.9, 1.0)),
    ])
    trainset = datasets.MNIST(root='../data/MNIST', train=True, download=True, transform=transform)
    trainloader = DataLoader(dataset=trainset, batch_size=_batch_size, shuffle=True)
    testset = datasets.MNIST(root='../data/MNIST', train=False, download=True, transform=transform)
    testloader = DataLoader(dataset=testset, batch_size=_batch_size, shuffle=False)

    net = AE()
    loss_fn = nn.MSELoss()  # reconstruction loss: output vs. input image
    opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.7)

    train_step = 0
    test_step = 0
    summary_step_interval = 100
    for epoch in range(_total_epoch):
        # ---- training ----
        net.train(True)
        for inputs, labels in trainloader:
            outputs = net(inputs)
            _loss = loss_fn(outputs, inputs)
            opt.zero_grad()
            _loss.backward()
            opt.step()
            if train_step % summary_step_interval == 0:
                # FIX: original f-strings were concatenated with no separator
                # ("...100loss:...").
                print(f'Train {epoch + 1}/{_total_epoch} {train_step} '
                      f'loss:{_loss.item():.3f}')
            train_step += 1

        # ---- evaluation ----
        net.eval()
        # FIX: evaluate under no_grad; the original built autograd graphs for
        # every test batch and for the saved sample, which made
        # utils.save_image fail on a grad-tracking tensor.
        with torch.no_grad():
            for inputs, labels in testloader:
                outputs = net(inputs)
                _loss = loss_fn(outputs, inputs)
                if test_step % summary_step_interval == 0:
                    print(f'Test {epoch + 1}/{_total_epoch} {test_step} '
                          f'loss:{_loss.item():.3f}')
                test_step += 1

            # Save one random test image next to its reconstruction.
            idx = np.random.randint(0, len(testset))
            img, label = testset[idx]
            img = img[None, ...]  # add batch dim: (1, 1, 28, 28)
            img2 = net(img)
            img = torch.cat([img, img2], dim=0)
            _dir = './output/ae/image'
            os.makedirs(_dir, exist_ok=True)  # FIX: race-free directory creation
            utils.save_image(img, f'{_dir}/{epoch}_{label}.png')
# Encoder-decoder (autoencoder) network for handwritten-digit image generation.
# (Source article originally published 2023-12-07 17:24:04.)