# Train by importing an existing dataset
# Data import and augmentation
import torchvision
from torchvision import transforms  # fixed: 'transforms' was used below but never imported

# Data augmentation pipeline for training: random crop / horizontal flip /
# small rotation, then convert to tensor and normalize each channel to [-1, 1].
# NOTE: defined BEFORE the dataset (original referenced it before assignment).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),       # fixed typo: was 'transfoems'
    transforms.ToTensor(),               # fixed typo: was 'Totensor'
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# CIFAR-10 training split; downloads into ./ on first run.
train_data = torchvision.datasets.CIFAR10(  # fixed: module is 'datasets', not 'dataset'
    root='./',
    train=True,
    transform=transform_train,
    download=True)                          # fixed: keyword is 'download', not 'down'
# Build the model
class AutoEncoder(nn.Module):
    """Autoencoder skeleton: encoder compresses the input, decoder reconstructs it.

    NOTE(review): both Sequential containers are empty in this source, so
    encode and decode are currently identity mappings — add layers as needed.
    """

    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential()
        self.decoder = nn.Sequential()  # fixed typo: was 'decoderr'

    def forward(self, x):
        # Fixed: the original returned undefined locals 'encode'/'decode'
        # (NameError at runtime). Compute them from the two sub-networks.
        encode = self.encoder(x)
        decode = self.decoder(encode)
        return encode, decode
# Load the model and data
import torch               # fixed: torch.optim was used below but torch was never imported
import torch.nn as nn      # fixed: nn was used but never imported in the visible source
import torch.utils.data as Data

BATCH_SIZE = 64  # mini-batch size; fixed: was referenced below but never defined
LR = 1e-3        # Adam learning rate; fixed: was referenced below but never defined

auto = AutoEncoder()

# Batch loader over the training set, shuffled each epoch.
train_data_loader = Data.DataLoader(
    dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

optimizer = torch.optim.Adam(auto.parameters(), lr=LR)
loss_func = nn.MSELoss()  # reconstruction loss: decode vs. original input
# Training
EPOCH = 10  # number of full passes; fixed: was referenced but never defined in the visible source

# Training loop: reconstruct each batch with the autoencoder and minimize
# the MSE between the reconstruction and the original input (labels unused).
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_data_loader):  # fixed typo: was 'train_data_loade'
        encode, decode = auto(x)
        loss = loss_func(decode, x)  # reconstruction error against the input itself
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Report once per epoch; .item() prints a plain float instead of a tensor repr.
    print('epoch', epoch, 'loss', loss.item())