解析PEN_NET(基于金字塔式图像修复)——损失函数(感知损失,风格损失,
上一篇博客我简单介绍了下基于金字塔式的图像修复
我clone了该项目,并逐步分析,本篇文章主要讲解一下这个项目的损失函数的定义
传统基于均方误差损失
从一幅缺失的图像转换到一幅修复的图像
我们最常想到的就是MSE均方误差损失
通过比较每个像素的误差,最后取平方再平均,得到一个loss损失值,并反向传播,对各个层进行梯度下降
事实上,均方误差表现的也不错,但是在一些细节,特征上缺失,造成局部的模糊
以笔者的一个简单的图像修复demo,它解码出来的图片是这样的
可以很明显看出基于MSE损失恢复出来的图像,在修复理发店的字体的时候表现并不是很好,那几个字母其实是很模糊的,但是大体形状是修复出来的。
感知损失
基于均方误差损失的缺陷,2016年李飞飞团队提出感知损失
该网络核心思想是这样的,即然MSE的细节特征修复的不是很好,那我就找一个预训练出来的网络来捕捉其图像的特征细节,根据生成的特征图,计算其L1损失函数,作为损失值的一部分
在该项目的代码中,作者使用的是VGG19网络,将图像都流入该网络,并将其中的relu层中的特征图拿出来进行L1loss的计算。
风格损失Style Loss
风格损失其实用在风格迁移上较多,它的思想就是利用格拉姆矩阵(gram)来进行特征图差异的计算。我们一个预训练好的模型,它不同层输出的特征图所表达的特征是不一样的,浅层的特征是比较细节,而深层的特征是比较抽象的。而格拉姆矩阵其实就是计算两个向量的相关性,通过将不同特征图向量进行格拉姆矩阵计算后,能更好地把握整体图像的风格
格拉姆矩阵
比如一个i行j列的矩阵Xij
它的格拉姆矩阵实质就是X*XT
也就是向量xi和xj的内积
它可以表达不同通道上样式特征的相关性
实现
该项目的作者仍然采用的是vgg19网络来抽取特征,首先计算对特征进行reshape,并且计算出它对应的转置,然后计算其内积,最后除以矩阵中元素的个数(不需要除以batch_size)
在call这个调用方法中
它对原始图像和生成图像所抽取的特征图并不完全一样,最后合并到一起,计算其L1loss
对抗损失
由于整体是使用的GAN结构,所以还需要引入对抗损失
作者根据不同的gan结构选用了不同的损失函数
其他原理比较简单,如果是真实的样本,则返回真实的label标签,也就是1,如果是假的样本则返回假的标签0
整个项目的损失函数就用了这三种,其他模块的解析会在后续继续更新