CPC(二):代码阅读解析

源代码链接:https://github.com/davidtellez/contrastive-predictive-coding

代码主要结构分为三个部分:train_model.py, benchmark_model.py, data_util.py

data_util.py:

主要用于提供训练所需的数据,生成的图片如下

代码阅读思考:

文件data_util里面一共有几大类?每一个类别的作用是什么?

文件data_util一共分了四个大类分别是:

①MnistHandler():用于梳理MNSIT数据,一共定义了6个函数,init(), load_dataset(), process_batch(), get_batch(), get_batch_by_lables(), get_n_samples()

init():下载数据,将lena image储存到记忆库中

load_dataset(): 下载数据,这个函数是在github上面,MNIST下载函数直接复制粘贴过来的,返回值是x_train, y_train, x_val, y_val, x_test, y_test

process_batch(): 用于转化MNIST数据,将图片从28x28转化到64x64,将图片转化为RGB图像(原图像是黑白的)。将图像二值化(此处的二值化,是否就是以像素的形式以0和1进行表示?)从图片lena中随机修剪一小块,将像素的颜色转化,该颜色是由之前的从图片lena中得到的。然后将图片缩放到[-1, 1]的范围。返回值是batch

get_batch(): 用于选择一个子集,随机选择采样,并将采样的batch进行数据处理,返回值是batch.astype('float32'), labels.astype('int32')

get_batch_by_labels(): 用于选择一个子集,选择匹配标签的样本,重新找到样本,使用process_batch进行batch处理,返回值是batch.astype('float32'), labels.astype('int32')

get_n_samples(): 根据不同的要求选择样本,要求分别有train, valid, test. 返回值是y_len

 

②MnistGenerator():用于提供MNIST的数据,一共定义了5个函数,init(), iter(), next(), len(), next()

init():用于设置参数,并初始化MNIST数据。

iter():  返回self

len(): 返回n_batches

next(): 返回 x, y_h,其中y_h是y经过独热编码处理过的数据。

 

③MnistGenerator(): 用于提供生成的分类数目的列表,一共定义了5个函数,init(), iter(), next(), len(), next()

init():用于设置参数,并初始化MNIST数据。

iter():  返回self

len(): 返回n_batches

next(): 返回[x_images[idxs, ...], y_images[idxs, ...], sentence_labels[idxs, ...]],生成语句,设置正样本的顺序预测,保存sentence,重建实际图像,组合batch,并将之随机化。

 

④SameNumberGenerator(): 用于提供相似数字的列表,一共定义了5个函数,init(), iter(), next(), len(), next()

init():用于设置参数,并初始化MNIST数据。

iter():  返回self

len(): 返回n_batches

next(): 返回[x_images[idxs, ...], y_images[idxs, ...], sentence_labels[idxs, ...]],生成语句,设置正样本的顺序预测,保存sentence,重建实际图像,组合batch,并将之随机化。

 

⑤ 单独定义了一个函数plot_sequences(),用于将图片绘出来。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值