图像自编码器,在UCF101以及---数据集上已进行验证效果较好

本文介绍了一种使用VGG16构建的图像自编码器,结合自定义数据加载器,该模型已在UCF101数据集上得到验证,表现出良好的效果。通过比较原图和生成图,展示了模型的图像重构能力。
摘要由CSDN通过智能技术生成

图片自编码器(自定义数据加载器+VGG16+transposeVGG16)

vGG16自编码器
在这里插入图片描述
原图
生成在这里插入图片描述
生成图

// An highlighted block
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 21 15:21:29 2021

@author: HUANGYANGLAI
"""
from time import *
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torch.utils.data as data
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.metrics import accuracy_score
import pickle
import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import itertools


CNN_embed_dim = 512      # latent dim extracted by 2D CNN
img_x, img_y = 224, 224  # resize video 2d frame size(可能更改图片尺寸)
dropout_p = 0        # dropout probability(随机失活)

#训练参数
k = 2         # 这里没用
epochs = 5000   # (迭代次数)
batchsize = 8  #(批处理)

learning_rate = 0.00005#本设计汇总可使用动态的方式来调整学习率

log_interval = 10 
flag=True#vgg官方数据初始化数据

resume=True#是否断点重新连接
###############################################################################进行自定义数据加载###############################
root='D://sign_first//video//'
print(root)

#定义读取文件的格式
def default_loader(path):
    return Image.open(path).convert('RGB')#路径必须指名哪一张图,不能是指定文件夹

#创建自己的类: MyDataset,这个类是继承的torch.utils.data.Dataset
class MyDataset(Dataset):
     #使用__init__()初始化一些需要传入的参数及数据集的调用
     #初始化文件路径或文件名列表
     def __init__(self,txt,imgs=None,transform=None,target_transform=None, loader=default_loader):
         super(MyDataset,self).__init__() #对继承自父类的属性进行初始化
         fh=open(txt,'r') #按照传入的路径和txt文本参数,以只读的方式打开这个文本
         imgs=[]
         for line in fh: #迭代该列表按行循环txt文本中的内容
             line=line.rstrip('\n') # 删除 本行string 字符串末尾的指定字符
             words=line.split(',')#通过每行的逗号来分开每一行的数
             imgs.append([words[0],int(words[1])])#word0图片信息,word1标签信息
         self.imgs=imgs
         self.transform = transform
         self.target_transform=target_transform
         self.loader=loader
         
 #使用__getitem__()对数据进行预处理并返回想要的信息
     def __getitem__(self,index):#用于按照索引读取每个元素的具体内容
         fn,label=self.imgs[index]
                         #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
         img=self.loader(fn)
                         #按照路径读取图片
         if self.transform is not None:
             img=self.transform(img)
             #print("数据标签转换",img)
             
             #数据标签转换成tensor
             
         return img,label #return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
     
     def __len__(self):#这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
         print("长度",len(self.imgs[0]))
         return len(self.imgs)
     

transform = transforms.Compose([transforms.Resize([img_x, img_y]),#改变形状
                                torchvision.transforms.Grayscale(num_output_channels=1),
                                transforms.ToTensor(),
                                #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

                                ])
       
train_data=MyDataset(txt=root+'train.txt', transform=transform)
#数据集制作完成
#数据加载器使用
data_loader = torch.utils.data.DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=3) 
############################################################################数据定义及加载器结束###############################

class VGG16(nn.Module):
    def __init__(self,init_weights=True):
        super(VGG16, self).__init__()
        
        # 3 * 224 * 224
        self.conv1_1 = nn.Conv2d(1, 64, 3) # 64 * 222 * 222
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1)) # 64 * 222* 222
        self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1),return_indices=True) # pooling 64 * 112 * 112
        
        self.conv2_1 = nn.Conv2d(64, 128, 3) # 128 * 110 * 110
        self.conv
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值