深度学习_GAN_CGAN学习笔记
我们把监督式学习的思想也放在生成模型上,我们期待的结果是希望可以根据网络输入的标签或者参数来生成对应的输出。
但是带标签数据的生成存在很多问题,在传统的生成模型上,包括传统的神经网络和原始GAN,都无法很好地控制数据生成的模式。
这时,条件式生成对抗网络(CGAN, Conditional Generative Adversarial Networks)作为一个GAN的改进,可以通过参数的控制来指导数据的生成,其一定程度上解决了GAN生成结果的不确定性。
如果在MNIST数据集上训练原始GAN,GAN生成的图像是完全不确定的,具体生成的是数字1,还是2,还是几,根本不可控。为了让生成的数字可控,我们可以把数据集做一个切分,把数字0-9的数据集分别拆分开训练9个模型,不过这样麻烦了,也不现实。因为数据集拆分不仅仅是分类麻烦,更主要在于,每一个类别的样本少,拿去训练GAN很有可能导致欠拟合。因此,CGAN就应运而生了。我们先看一下CGAN的网络结构:
从网络结构图可以看到,对于生成器Generator,其输入不仅仅是随机噪声的采样Z,还有预生成图像的标签信息。比如对于MNIST数据生成,就是一个one-hot向量,某一维度为1则表示某个数字的图片。同样地,判别器的输入也包括样本的标签。这样就使得判别器和生成器可以学习到样本和标签之间的联系。其目标函数如下:
目标函数设计和原始GAN的整体结构基本一致,只不过生成器,判别器的输入数据是一个条件分布,即加入了额外的辅助信息Y,这个Y可以是该数据的分类标签等。在具体编程实现时需要对随机噪声采样Z和输入条件Y做一个拼接组合(级联),这样就形成了一个全新的隐含表示。
其结构图也可以用下图表示: