图片自编码器(自定义数据加载器+VGG16+transposeVGG16)
vGG16自编码器
原图
生成图
// An highlighted block
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 21 15:21:29 2021
@author: HUANGYANGLAI
"""
from time import *
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torch.utils.data as data
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.metrics import accuracy_score
import pickle
import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import itertools
CNN_embed_dim = 512 # latent dim extracted by 2D CNN
img_x, img_y = 224, 224 # resize video 2d frame size(可能更改图片尺寸)
dropout_p = 0 # dropout probability(随机失活)
#训练参数
k = 2 # 这里没用
epochs = 5000 # (迭代次数)
batchsize = 8 #(批处理)
learning_rate = 0.00005#本设计汇总可使用动态的方式来调整学习率
log_interval = 10
flag=True#vgg官方数据初始化数据
resume=True#是否断点重新连接
###############################################################################进行自定义数据加载###############################
root='D://sign_first//video//'
print(root)
#定义读取文件的格式
def default_loader(path):
return Image.open(path).convert('RGB')#路径必须指名哪一张图,不能是指定文件夹
#创建自己的类: MyDataset,这个类是继承的torch.utils.data.Dataset
class MyDataset(Dataset):
#使用__init__()初始化一些需要传入的参数及数据集的调用
#初始化文件路径或文件名列表
def __init__(self,txt,imgs=None,transform=None,target_transform=None, loader=default_loader):
super(MyDataset,self).__init__() #对继承自父类的属性进行初始化
fh=open(txt,'r') #按照传入的路径和txt文本参数,以只读的方式打开这个文本
imgs=[]
for line in fh: #迭代该列表按行循环txt文本中的内容
line=line.rstrip('\n') # 删除 本行string 字符串末尾的指定字符
words=line.split(',')#通过每行的逗号来分开每一行的数
imgs.append([words[0],int(words[1])])#word0图片信息,word1标签信息
self.imgs=imgs
self.transform = transform
self.target_transform=target_transform
self.loader=loader
#使用__getitem__()对数据进行预处理并返回想要的信息
def __getitem__(self,index):#用于按照索引读取每个元素的具体内容
fn,label=self.imgs[index]
#fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
img=self.loader(fn)
#按照路径读取图片
if self.transform is not None:
img=self.transform(img)
#print("数据标签转换",img)
#数据标签转换成tensor
return img,label #return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
def __len__(self):#这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
print("长度",len(self.imgs[0]))
return len(self.imgs)
transform = transforms.Compose([transforms.Resize([img_x, img_y]),#改变形状
torchvision.transforms.Grayscale(num_output_channels=1),
transforms.ToTensor(),
#transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
train_data=MyDataset(txt=root+'train.txt', transform=transform)
#数据集制作完成
#数据加载器使用
data_loader = torch.utils.data.DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=3)
############################################################################数据定义及加载器结束###############################
class VGG16(nn.Module):
def __init__(self,init_weights=True):
super(VGG16, self).__init__()
# 3 * 224 * 224
self.conv1_1 = nn.Conv2d(1, 64, 3) # 64 * 222 * 222
self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1)) # 64 * 222* 222
self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1),return_indices=True) # pooling 64 * 112 * 112
self.conv2_1 = nn.Conv2d(64, 128, 3) # 128 * 110 * 110
self.conv