text to image(七):《TAC-GAN 》
继续介绍文本生成图像的工作,本篇博客要给出的是2017年3月19号发表于arXiv的《TAC-GAN – Text Conditioned Auxiliary Classifier Generative Adversarial Network》 。
论文地址:https://arxiv.org/abs/1703.06412v2
源码地址:https://github.com/dashayushman/TAC-GAN
论文结构非常清晰明了,通俗易懂,点赞!
一、相关工作
首先是GAN网络的理解:https://blog.****.net/zlrai5895/article/details/80648898
作者将实验结果与StackGAN做了对比,StackGAN的相关工作:https://blog.****.net/zlrai5895/article/details/81292167
文章中的模型结构是在AC-GAN的基础上做了扩展,下面简单介绍一下AC-GAN。
AC-GAN是一种condition GAN,它以类别标签为condition,并且鉴别器不仅仅区分真实图片和合成图片,还给它们分配标签。
AC-GAN的输入:
(1)noise z
(2)文本描述的text embedding
用代表生成器生成的图像,生成的图片与label class (c)和z有关,则有
损失包含了两部分:图片和标签
详细可参考论文:《Conditional image synthesis with auxiliary classifier gans》 。
此外,TAC-GAN模型结构实现参考了DCGAN,来自论文《Unsupervised representation learning with deep convolutional generative adversarial networks.》
二、基本思想及成果
TAC-GAN使用 condition GAN的思想,使用文本合成了分辨率128*128的图像。相比StackGAN,它的inception score有了7.8%的提升。不过分辨率没有StackGAN的高(256*256)。
inception score 的定义可以参考https://blog.****.net/sparkkkk/article/details/72565975。
三、数据集介绍:
实验使用的数据集是Oxford-102。
1.数据集简介:
数据集由102类产自英国的花卉组成。每类由40-258张图片组成。
2、数据集结构:
下载完数据集,解压后可得到一个包含8189张.jpg格式的图片。数据集标签需额外下载。
数据集结构:
-imagelabels.mat:总共有8189列,每列上的数字代表类别号。
-setid.mat
-trnid.mat:总共有1020列,每10列为一类花卉的图片,每列上的数字代表图片号。
-valid.mat:总共有1020列,每10列为一类花卉的图片,每列上的数字代表图片号。
-tstid.mat:总共有6149列,每一类花卉的列数不定,每列上的数字代表图片号。
四、模型结构
先上整体结构图,另外,下文提到的向量具体数值均是为了便于理解,并不完全与代码运行后的实际数值相同。
1、text embedding
文章使用Skip-Thought 来从caption中生成text embedding.关于Skip-Thought的理解:
https://zhuanlan.zhihu.com/p/32953049
来自文章《Skip-Thought Vectors》,这篇论文使用从书籍中提取的连续文本,训练了一个编码器、解码器网络,借鉴了word2vec中skip-gram模型,通过一句话来预测这句话的上一句和下一句。模型被称为skip-thoughts,生成的向量称为skip-thought vector。模型采用了当下流行的端到端框架,通过搜集了大量的小说作为训练数据集,将得到的模型中encoder部分作为feature extractor,可以给任意句子生成vector。我们就利用encoder 部分产生text embedding。
2、生成器
首先,数据集中的每个数据由三部分组成:
代表了一幅图片
是这幅图片的k个文本描述。
是图片对应的标签。
训练时,对于某张图片随机挑一个
来生成一个text_embedding(
),接一个全连接层得到
,且
,将其与noise连接在一起,得到
连接到一个输出为的全连接层,并将输出reshape成
。将它输入生成器后续的卷积网络,上采样成128*128*3的
。
2、鉴别器
鉴别器使用了这样的一个集合:
(图片、类别、文本描述)
又可以被写为
鉴别器的第一个输入是集合中的任意一个 I,图像被下采样到
鉴别器的第二个输入是,它被replicated到
二者concat,然后经过一系列卷积层等,最后输出两个全连接层,一个有1个神经元,另一个有个神经元,第一个FC提供是真还是假的概率分布
,第二个提供类别的概率分布
。
鉴别器的设计受到了GAN-CLS的启发(来自《Generative adversarial text to image synthesis》)。
可以参考博客https://blog.****.net/zlrai5895/article/details/81416053
五、损失函数
被计为判别器的输出和每幅图像的期望之间的二进制交叉熵之和。
鉴别器希望能真的判定为真,假的判定为假。
被即为输出类别与期望类别的熵之和。
鉴别器损失+
生成器损失要依靠鉴别器的鉴别结果,希望能欺骗鉴别器。
也是图像和类别两部分。
六、其他
1、模型结构比较容易扩展,可以更换condition。
2、作者认为可以尝试一下使用多阶段的GAN,可能会有更好的效果。