CPC(二):代码阅读解析
源代码链接:https://github.com/davidtellez/contrastive-predictive-coding
代码主要结构分为三个部分:train_model.py, benchmark_model.py, data_util.py
data_util.py:
主要用于提供训练所需的数据,生成的图片如下
代码阅读思考:
文件data_util里面一共有几大类?每一个类别的作用是什么?
文件data_util一共分了四个大类分别是:
①MnistHandler():用于梳理MNSIT数据,一共定义了6个函数,init(), load_dataset(), process_batch(), get_batch(), get_batch_by_lables(), get_n_samples()
init():下载数据,将lena image储存到记忆库中
load_dataset(): 下载数据,这个函数是在github上面,MNIST下载函数直接复制粘贴过来的,返回值是x_train, y_train, x_val, y_val, x_test, y_test
process_batch(): 用于转化MNIST数据,将图片从28x28转化到64x64,将图片转化为RGB图像(原图像是黑白的)。将图像二值化(此处的二值化,是否就是以像素的形式以0和1进行表示?)从图片lena中随机修剪一小块,将像素的颜色转化,该颜色是由之前的从图片lena中得到的。然后将图片缩放到[-1, 1]的范围。返回值是batch
get_batch(): 用于选择一个子集,随机选择采样,并将采样的batch进行数据处理,返回值是batch.astype('float32'), labels.astype('int32')
get_batch_by_labels(): 用于选择一个子集,选择匹配标签的样本,重新找到样本,使用process_batch进行batch处理,返回值是batch.astype('float32'), labels.astype('int32')
get_n_samples(): 根据不同的要求选择样本,要求分别有train, valid, test. 返回值是y_len
②MnistGenerator():用于提供MNIST的数据,一共定义了5个函数,init(), iter(), next(), len(), next()
init():用于设置参数,并初始化MNIST数据。
iter(): 返回self
len(): 返回n_batches
next(): 返回 x, y_h,其中y_h是y经过独热编码处理过的数据。
③MnistGenerator(): 用于提供生成的分类数目的列表,一共定义了5个函数,init(), iter(), next(), len(), next()
init():用于设置参数,并初始化MNIST数据。
iter(): 返回self
len(): 返回n_batches
next(): 返回[x_images[idxs, ...], y_images[idxs, ...], sentence_labels[idxs, ...]],生成语句,设置正样本的顺序预测,保存sentence,重建实际图像,组合batch,并将之随机化。
④SameNumberGenerator(): 用于提供相似数字的列表,一共定义了5个函数,init(), iter(), next(), len(), next()
init():用于设置参数,并初始化MNIST数据。
iter(): 返回self
len(): 返回n_batches
next(): 返回[x_images[idxs, ...], y_images[idxs, ...], sentence_labels[idxs, ...]],生成语句,设置正样本的顺序预测,保存sentence,重建实际图像,组合batch,并将之随机化。
⑤ 单独定义了一个函数plot_sequences(),用于将图片绘出来。