代码实例:教你实现infoGAN
实例:构建infoGAN生成MNIST模拟数据
本例演示在MNISTt数据集上使用infoGan网络模型生成模拟数据,并且加入标签信息的loss函数同时实现了AC-GAN的网络。其中的D和G都是用卷积网络来实现的,相当于DCGAN上面的infoGAN例子。
案例描述
通过使用InfoGAN网络学习MNIST数据特征,生成以假乱真的MNIST模拟样本,并发现内部潜在特征信息。
具体实现可以分为如下几个步骤:
1. 引入头文件并加载MNIST数据
假设MNIST数据放在本地磁盘跟目录的data下。本例中将使用前面介绍的slim模块构建网络结构,所以需要引入slim。当然也可以不用slim,引入slim的目的就是为了编写代码比较方便,不用考虑输入维度即相关权重的定义。最主要是slim还对反卷积有封装,下文会用到。
建立2个噪声数据(一般噪声和隐含信息)与label结合放到生成器中,生成模拟样本,然后将模拟样本和真实样本分别输入到判别器中,生成判别结果,重构造的隐含信息,还有样本标签。
在优化时,让判别器对真实的判别结果为1,对模拟数据的判别结果为0来做loss。对生成器让判别结果为1来做loss。
3.定义生成器与判别器
由于是先从模拟噪声数据来恢复样本,所以在生成器中。要使用反卷积函数。这里通过2个全连接,再接入两个反卷积完成样本的模拟生成的。并且每一层都有BN归一化处理。
令一般噪声的维度为38对应节点为z_rand,隐含信息维度为2对应节点z_con,二者都是符合标准高斯分布的随机数。将它们与one_hot转换后的标签连接一起放到生成器中。
5. 定义损失函数与优化器
对于判别器中判别结果的loss有两个:真实输入的结果与模拟输入的结果,将二者和在一起生成loss_d。对于生成器的loss为自己输出的模拟数据让它在判别器中为真,定义为loss_g。
剩下还要定义网络中共有的loss值:真实的标签与输入真实样本判别出的标签、真实的标签与输入模拟样本判别出的标签、隐含信息的重构误差。定义好后创建两个优化器,将它们放到对应的优化器中。
这里用了个技巧将判别器的学习率设小,将生成器的学习率设大些。这么做是为了让生成器有更快的进化速度来模拟真实数据。优化同样是用AdamOptimizer方法。具体代码如下:
6. 开始训练与测试
建立session,在循环里使用run来运行前面构建的两个优化器。
7.可视化
这部分通过两种显示来可视化结果:生成原样本与对应的模拟数据图片、生成隐含信息对应的图片。生成原样本与对应的模拟数据图片会将对应的分类以及预测分类还有隐含信息一起打印出来。生成隐含信息对应的图片中将在整个[0,1]空间里抽样与样本的标签混合一起,生成模拟数据。
本例演示在MNISTt数据集上使用infoGan网络模型生成模拟数据,并且加入标签信息的loss函数同时实现了AC-GAN的网络。其中的D和G都是用卷积网络来实现的,相当于DCGAN上面的infoGAN例子。
案例描述
通过使用InfoGAN网络学习MNIST数据特征,生成以假乱真的MNIST模拟样本,并发现内部潜在特征信息。
具体实现可以分为如下几个步骤:
1. 引入头文件并加载MNIST数据
假设MNIST数据放在本地磁盘跟目录的data下。本例中将使用前面介绍的slim模块构建网络结构,所以需要引入slim。当然也可以不用slim,引入slim的目的就是为了编写代码比较方便,不用考虑输入维度即相关权重的定义。最主要是slim还对反卷积有封装,下文会用到。
代码12-1 Mnistinfogan
建立2个噪声数据(一般噪声和隐含信息)与label结合放到生成器中,生成模拟样本,然后将模拟样本和真实样本分别输入到判别器中,生成判别结果,重构造的隐含信息,还有样本标签。
在优化时,让判别器对真实的判别结果为1,对模拟数据的判别结果为0来做loss。对生成器让判别结果为1来做loss。
3.定义生成器与判别器
由于是先从模拟噪声数据来恢复样本,所以在生成器中。要使用反卷积函数。这里通过2个全连接,再接入两个反卷积完成样本的模拟生成的。并且每一层都有BN归一化处理。
代码12-1 Mnistinfogan(续)
对于判别器的输入是真正的样本,同样的也是经过两次卷积,在接两次全连接,生成的数据可以分别接不同的输出层,来产生不同的结果:1维输出对应判别结果1 还是0;10维输出对应分类结果;2维输出对应隐含维度信息。
令一般噪声的维度为38对应节点为z_rand,隐含信息维度为2对应节点z_con,二者都是符合标准高斯分布的随机数。将它们与one_hot转换后的标签连接一起放到生成器中。
代码12-1 Mnistinfogan(续)
5. 定义损失函数与优化器
对于判别器中判别结果的loss有两个:真实输入的结果与模拟输入的结果,将二者和在一起生成loss_d。对于生成器的loss为自己输出的模拟数据让它在判别器中为真,定义为loss_g。
剩下还要定义网络中共有的loss值:真实的标签与输入真实样本判别出的标签、真实的标签与输入模拟样本判别出的标签、隐含信息的重构误差。定义好后创建两个优化器,将它们放到对应的优化器中。
这里用了个技巧将判别器的学习率设小,将生成器的学习率设大些。这么做是为了让生成器有更快的进化速度来模拟真实数据。优化同样是用AdamOptimizer方法。具体代码如下:
代码12-1 Mnistinfogan(续)
6. 开始训练与测试
建立session,在循环里使用run来运行前面构建的两个优化器。
代码12-1 Mnistinfogan(续)
测试部分分别使用loss_d和loss_g的eval来完成。上面代码运行后得到如下输出:
7.可视化
这部分通过两种显示来可视化结果:生成原样本与对应的模拟数据图片、生成隐含信息对应的图片。生成原样本与对应的模拟数据图片会将对应的分类以及预测分类还有隐含信息一起打印出来。生成隐含信息对应的图片中将在整个[0,1]空间里抽样与样本的标签混合一起,生成模拟数据。
代码12-1 Mnistinfogan(续)
上面代码运行后,生成如下结果: