# PyTorch + VGG16 convolutional autoencoder (used to train the image generator).
# Supports resuming training from a checkpoint.
# Supports a dynamic learning rate.
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 21 15:21:29 2021
@author: HUANGYANGLAI
"""
from time import *
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torch.utils.data as data
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.metrics import accuracy_score
import pickle
import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import itertools
CNN_embed_dim = 512 # latent dim extracted by the 2D CNN encoder
img_x, img_y = 224, 224 # 2D frame size every image is resized to (change if needed)
dropout_p = 0 # dropout probability (appears unused below — TODO confirm)
# training hyper-parameters
k = 2 # unused here
epochs = 5000 # number of training epochs
batchsize = 8 # mini-batch size
learning_rate = 0.00005 # base LR; a loss-driven dynamic schedule exists (commented out) in the loop
log_interval = 10
flag=True # pass init_weights=True to VGG16 (kaiming/normal init)
resume=True # whether to resume training from the saved checkpoint
############################################################################### custom data loading ###############################
root='D://sign_first//video//'
print(root)
# Image-reading callable shared by the dataset; `path` must name a single image file.
def default_loader(path):
    """Load one image from *path* and return it as an RGB PIL image.

    BUG FIX: use a context manager so the underlying file handle is closed
    promptly — ``Image.open`` is lazy and otherwise keeps the file open,
    which exhausts file descriptors over a long training run.
    ``convert`` forces the pixel data to be read before the file closes.
    """
    with Image.open(path) as img:
        return img.convert('RGB')
# Custom Dataset subclass of torch.utils.data.Dataset.
class MyDataset(Dataset):
    """Dataset backed by a plain-text index file.

    Each non-empty line of *txt* has the form ``<image-path>,<int-label>``.

    Parameters
    ----------
    txt : str
        Path to the index file.
    imgs : unused
        Kept for backward compatibility with existing callers.
    transform / target_transform : callable or None
        Optional transforms applied to the image / label.
    loader : callable
        ``path -> PIL.Image``; defaults to :func:`default_loader`.
    """

    def __init__(self, txt, imgs=None, transform=None, target_transform=None,
                 loader=default_loader):
        super(MyDataset, self).__init__()
        samples = []
        # BUG FIX: close the index file deterministically (it was left open).
        with open(txt, 'r') as fh:
            for line in fh:
                line = line.rstrip('\n')
                if not line:            # tolerate blank/trailing lines
                    continue
                words = line.split(',')  # words[0] = image path, words[1] = label
                samples.append([words[0], int(words[1])])
        self.imgs = samples
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at *index*."""
        fn, label = self.imgs[index]
        img = self.loader(fn)            # read the image from its path
        if self.transform is not None:
            img = self.transform(img)    # e.g. resize/grayscale/ToTensor
        return img, label

    def __len__(self):
        # Number of samples (the old debug print of len(self.imgs[0]) — always
        # 2, the [path, label] pair length — has been removed).
        return len(self.imgs)
# Preprocessing pipeline: resize to (img_x, img_y), collapse to one gray
# channel, convert to a float tensor in [0, 1].  No normalization is applied.
_pipeline = [
    transforms.Resize([img_x, img_y]),
    torchvision.transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
]
transform = transforms.Compose(_pipeline)
# Dataset built from the training index file, then wrapped in a shuffling loader.
train_data = MyDataset(txt=root + 'train.txt', transform=transform)
data_loader = torch.utils.data.DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=3)
############################################################################数据定义及加载器结束###############################
class VGG16(nn.Module):
    """VGG16-style encoder for a 1-channel 224x224 input.

    Differences from torchvision's VGG16: the first conv takes 1 channel,
    the first conv of each stage uses no padding, every max-pool uses
    padding=1 and returns its pooling indices (so a matching decoder can
    invert the pooling with ``MaxUnpool2d``), and the classifier is replaced
    by a 512*7*7 -> 250 -> 250 bottleneck.

    ``forward`` returns ``(latent, indices1..indices5)``.

    BUG FIX: the per-batch debug ``print`` in ``forward`` has been removed —
    it wrote to stdout on every forward pass.
    """

    def __init__(self, init_weights=True):
        super(VGG16, self).__init__()
        # Stage 1: 1 x 224 x 224 -> 64 x 112 x 112
        self.conv1_1 = nn.Conv2d(1, 64, 3)                   # 64 x 222 x 222
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1))  # 64 x 222 x 222
        self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1), return_indices=True)  # 64 x 112 x 112
        # Stage 2: -> 128 x 56 x 56
        self.conv2_1 = nn.Conv2d(64, 128, 3)                   # 128 x 110 x 110
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=(1, 1))  # 128 x 110 x 110
        self.maxpool2 = nn.MaxPool2d((2, 2), padding=(1, 1), return_indices=True)  # 128 x 56 x 56
        # Stage 3: -> 256 x 28 x 28
        self.conv3_1 = nn.Conv2d(128, 256, 3)                  # 256 x 54 x 54
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=(1, 1))  # 256 x 54 x 54
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=(1, 1))  # 256 x 54 x 54
        self.maxpool3 = nn.MaxPool2d((2, 2), padding=(1, 1), return_indices=True)  # 256 x 28 x 28
        # Stage 4: -> 512 x 14 x 14
        self.conv4_1 = nn.Conv2d(256, 512, 3)                  # 512 x 26 x 26
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 x 26 x 26
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 x 26 x 26
        self.maxpool4 = nn.MaxPool2d((2, 2), padding=(1, 1), return_indices=True)  # 512 x 14 x 14
        # Stage 5: -> 512 x 7 x 7
        self.conv5_1 = nn.Conv2d(512, 512, 3)                  # 512 x 12 x 12
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 x 12 x 12
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 x 12 x 12
        self.maxpool5 = nn.MaxPool2d((2, 2), padding=(1, 1), return_indices=True)  # 512 x 7 x 7
        # Bottleneck head: flattened features -> 250-d latent code.
        self.fc1 = nn.Sequential(
            nn.Linear(512 * 7 * 7, 250),
            nn.ReLU(inplace=True),
            nn.Linear(250, 250),
            nn.ReLU(inplace=True),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming init for convs, N(0, 0.01) for linears, constants for BN."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Encode *x* (N, 1, 224, 224) -> latent (N, 250) plus pool indices.

        The five index tensors are needed by the decoder's MaxUnpool2d layers.
        """
        in_size = x.size(0)  # batch size
        out = F.relu(self.conv1_1(x))        # 222
        out = F.relu(self.conv1_2(out))      # 222
        out, indices1 = self.maxpool1(out)   # 112
        out = F.relu(self.conv2_1(out))      # 110
        out = F.relu(self.conv2_2(out))      # 110
        out, indices2 = self.maxpool2(out)   # 56
        out = F.relu(self.conv3_1(out))      # 54
        out = F.relu(self.conv3_2(out))      # 54
        out = F.relu(self.conv3_3(out))      # 54
        out, indices3 = self.maxpool3(out)   # 28
        out = F.relu(self.conv4_1(out))      # 26
        out = F.relu(self.conv4_2(out))      # 26
        out = F.relu(self.conv4_3(out))      # 26
        out, indices4 = self.maxpool4(out)   # 14
        out = F.relu(self.conv5_1(out))      # 12
        out = F.relu(self.conv5_2(out))      # 12
        out = F.relu(self.conv5_3(out))      # 12
        out, indices5 = self.maxpool5(out)   # 7
        out = out.view(in_size, -1)          # flatten to (N, 512*7*7)
        out = self.fc1(out)
        return out, indices1, indices2, indices3, indices4, indices5
class DecoderCNN1(nn.Module):
    """Mirror decoder for the VGG16 encoder.

    Maps a 250-d latent code back to a 1 x 224 x 224 image.  Each stage
    first inverts the corresponding encoder max-pool with ``MaxUnpool2d``
    (using the indices the encoder returned), then applies a run of
    ConvTranspose2d + BatchNorm + ReLU units.  Attribute names match the
    original layout so existing checkpoints load unchanged.
    """

    @staticmethod
    def _up(c_in, c_out, pad):
        # One "deconvolution" unit: ConvTranspose2d(k=3, stride=1) + BN + ReLU.
        return nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, 3, 1, pad, 0),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        super(DecoderCNN1, self).__init__()
        # Expand the latent vector back to a 512 x 7 x 7 feature map.
        self.fc = nn.Sequential(
            nn.Linear(250, 512 * 7 * 7),
            nn.ReLU(inplace=True),
        )
        # Stage 1: 7 -> 12
        self.unpool1 = nn.MaxUnpool2d(2, stride=2, padding=1)
        self.convtran1 = self._up(512, 512, 1)
        self.convtran12 = self._up(512, 512, 1)
        self.convtran13 = self._up(512, 512, 0)
        # Stage 2: 14 -> 26
        self.unpool2 = nn.MaxUnpool2d(2, stride=2, padding=1)
        self.convtran2 = self._up(512, 512, 1)
        self.convtran22 = self._up(512, 512, 1)
        self.convtran23 = self._up(512, 256, 0)
        # Stage 3: 28 -> 54
        self.unpool3 = nn.MaxUnpool2d(2, stride=2, padding=1)
        self.convtran3 = self._up(256, 256, 1)
        self.convtran32 = self._up(256, 256, 1)
        self.convtran33 = self._up(256, 128, 0)
        # Stage 4: 56 -> 110
        self.unpool4 = nn.MaxUnpool2d(2, stride=2, padding=1)
        self.convtran4 = self._up(128, 128, 1)
        self.convtran42 = self._up(128, 64, 0)
        # Stage 5: 112 -> 224
        self.unpool5 = nn.MaxUnpool2d(2, stride=2, padding=1)
        self.convtran5 = self._up(64, 64, 1)
        self.convtran52 = self._up(64, 1, 0)

    def forward(self, x, indices1, indices2, indices3, indices4, indices5):
        """Decode latent *x* (N, 250) to (N, 1, 224, 224) using the encoder's
        pooling indices (note: applied in reverse order, indices5 first)."""
        batch = x.size(0)
        out = self.fc(x).view(batch, 512, 7, 7)
        stages = (
            (self.unpool1, indices5, (self.convtran1, self.convtran12, self.convtran13)),
            (self.unpool2, indices4, (self.convtran2, self.convtran22, self.convtran23)),
            (self.unpool3, indices3, (self.convtran3, self.convtran32, self.convtran33)),
            (self.unpool4, indices2, (self.convtran4, self.convtran42)),
            (self.unpool5, indices1, (self.convtran5, self.convtran52)),
        )
        for unpool, idx, convs in stages:
            out = unpool(out, idx)
            for conv in convs:
                out = conv(out)
        return out
############################################################### device, models, optimizer, loss #############################################
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Encoder/decoder pair; `flag` controls whether VGG16 runs its weight init.
cnn_encoder = VGG16(init_weights=flag).to(device)
cnn_decoder=DecoderCNN1().to(device)
#c_params=list(cnn_encoder.parameters())+list(cnn_decoder.parameters())
optimizer = torch.optim.Adam(itertools.chain(cnn_encoder.parameters(),cnn_decoder.parameters()), lr=learning_rate)# one Adam jointly optimizes encoder and decoder parameters
loss_func=torch.nn.MSELoss()  # reconstruction loss against the input image
if __name__ == "__main__":
    ckpt_path = 'D:/sign_first/video_code/checkpoint/check_point.pth'
    # First epoch to run.  BUG FIX: this was only assigned inside the resume
    # branch, so resume=False crashed with NameError at range(initepoch, ...).
    initepoch = 0
    if resume and os.path.isfile(ckpt_path):
        # Restore models + optimizer and continue from the epoch after the
        # saved one.  The isfile guard avoids crashing on a first-ever run.
        print("Resume from checkpoint...")
        checkpoint = torch.load(ckpt_path)
        cnn_encoder.load_state_dict(checkpoint['model1_state_dict'])
        cnn_decoder.load_state_dict(checkpoint['model2_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        initepoch = checkpoint['epoch'] + 1
        print("断点是第多少次", initepoch)
    for epoch in range(initepoch, epochs):
        print("epoch:", epoch)
        for step, (b_x, b_y) in enumerate(data_loader):
            print("step:", step)
            # BUG FIX: move the batch to the models' device; the original fed
            # CPU tensors to (potentially) CUDA models.
            b_x = b_x.to(device)
            optimizer.zero_grad()
            out, indices1, indices2, indices3, indices4, indices5 = cnn_encoder(b_x)
            weneed = cnn_decoder(out, indices1, indices2, indices3, indices4, indices5)
            # Autoencoder objective: reconstruct the input itself.
            loss = loss_func(weneed, b_x)
            print("loss", loss)
            loss.backward()
            optimizer.step()
        # NOTE(review): a loss-driven dynamic-LR schedule used to live here
        # (commented out); torch.optim.lr_scheduler.ReduceLROnPlateau is the
        # supported way to get the same behavior.
        # Checkpoint once per epoch (the original appeared to save every step,
        # which is excessive disk I/O for identical recoverability).
        checkpoint = {
            'epoch': epoch,
            'model1_state_dict': cnn_encoder.state_dict(),
            'model2_state_dict': cnn_decoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(checkpoint, ckpt_path)
        # Full-model snapshots (pickled modules, not just state dicts).
        torch.save(cnn_encoder, 'encodernet1.pkl')
        torch.save(cnn_decoder, 'decodernet2.pkl')