2020-10-26

TOC]生成网络GAN
基本思想

GAN的主要结构包括一个生成器G(Generator)和一个判别器D(Discriminator)。对于GAN,一个简单的理解是可以将其看做博弈的过程,我们可以将生成模型和判别模型看作博弈的双方,比如在识别手写数字的数据集的过程中:
生成模型G相当于制造新图像的一方,其目的是根据看到的图片情况和判别器的识别技术,去尽量生成更加真实的、识别不出的假币。
判别模型D相当于识别图像的一方,其目的是尽可能的识别出生成器制造的假图像。 这样通过双方的较量和朝目的的改进,使得最后能达到生成模型能尽可能真的、识假者判断不出真假的纳什均衡效果(真假概率都为0.5)。

GAN具体过程

定义一个模型来作为生成器,能够输入一个向量,输出手写数字大小的像素图像。
定义一个分类器来作为判别器用来判别图片是真的还是假的(或者说是来自数据集中的还是生成器中生成的),输入为手写图片,输出为判别图片的标签。

训练流程如下:
1.初始化判别器D的参数 和生成器G的参数 。
2.从真实样本中采样m 个样本 ,从先验分布噪声中采样 m个噪声样本并通过生成器获取 m个生成样本 。固定生成器G,训练判别器D尽可能好地准确判别真实样本和生成样本,尽可能大地区分正确样本和生成的样本。
循环k次更新判别器之后,使用较小的学习率来更新一次生成器的参数,训练生成器使其尽可能能够减小生成样本与真实样本之间的差距,也相当于尽量使得判别器判别错误。
多次更新迭代之后,最终理想情况是使得判别器判别不出样本来自于生成器的输出还是真实的输出。亦即最终样本判别概率均为0.5。

之所以要先训练判别器,再训练生成器,是因为要先拥有一个好的判别器,使得能够教好地区分出真实样本和生成样本之后,才好更为准确地对生成器进行更新。更直观的理解可以参考下图:
2020-10-26

黑色虚线表示真实的样本的分布情况,蓝色虚线表示判别器判别概率的分布情况,绿色实线表示生成样本的分布。 Z 表示噪声, Z 到 x 表示通过生成器之后的分布的映射情况。