SeqGAN Sequence Generate Adversarial Nets with Policy Gradient(阅读理解)

SeqGAN Sequence Generate Adversarial Nets with Policy Gradient(阅读理解)

SeqGAN在目前是在使用生成对抗网络解决文本序列生成问题的最有影响力的一篇文章,作者针对对抗生成网络难以解决序列生成问题,提出了很多十分有价值的方法

问题:

(1)在GANs中,Generator是通过随机抽样作为开始,然后根据模型的参数进行确定性的转化。通过generative model G的输出,discriminative model D计算的损失值,根据得到的损失梯度去指导generative model G做轻微改变,从而使G产生更加真实的数据。而在文本生成任务中,G通常使用的是LSTM,那么G传递给D的是一堆离散值序列,即每一个LSTM单元的输出经过softmax之后再取argmax或者基于概率采样得到一个具体的单词,那么这使得梯度下架很难处理

(2)GAN只能评估出整个生成序列的score/loss,不能够细化到去评估当前生成token的好坏和对后面生成的影响

SeqGAN模型将强化学习和对抗思想的结合,解决非连续序列生成的问题,产生可用于文本序列生成的模型

重点:

利用强化学习修正生成器Generator,作者提出生成器G的目标是生成sequence来最大化reward的期望

SeqGAN Sequence Generate Adversarial Nets with Policy Gradient(阅读理解)

在这里把这个reward的期望叫做J(θ)。就是在s0和θ的条件下,产生某个完全的sequence的reward的期望。其中Gθ()部分可以轻易地看出就是Generator Model。而Q()(我在这里叫它Q值)在文中被叫做一个sequence的action-value function 。这个J(θ),也就是我们生成模型想要最大化的函数
Q值需要一个完整的序列才能从判别器返回一个评估值,不完整的序列的reward没有实际的意义,在这里,作者首次提出使用蒙特卡洛搜索模拟完整序列生成

利用MCTS方法,作者提出用部分真实序列作为起始输入,随机生成完整的序列,通过模拟,近似出完整的序列,同时,作者在采用了增量的方式逼近真实序列,假设第一次的部分真实序列长度为a,则下次迭代输入的真实序列长度为a+1,以此类推生成。使用蒙特卡洛搜索补全的所有可能的sequence全都计算reward,然后求平均。

SeqGAN Sequence Generate Adversarial Nets with Policy Gradient(阅读理解)


贡献

第一次将GANs和文本序列生成问题进行了很有效的结合,之前已有研究,但是本文提出的方法为后续研究做出了非常的铺垫,大多数相关的研究都以此展开

作者首次结合强化学习算法解决RNN生成器的优化问题,传统的方法都是基于MLE,会随着时间的推移出现bias累加的问题,而GANs的使用大大缓解了之前的问题

作者使用MCTS,解决了不完整序列无法评估的问题,同时通过一步一步训练Generator,有效的提高了生成器的准确率