本文主要参考B站的白板推导系列,推荐大家观看。
什么是变分推断
X : 观测变量
Z:latent variable + parameter
在变分推断中,样本点X被称为观测变量(observed data),未知参数和潜变量被称为不可观测变量,都用Z来表示。
我们的模型一般都是根据观测数据来求Z的后验分布,也就是求P(Z∣X),但是有的时候,P(Z∣X)是不好求的,所以可以尝试用一个容易求解的分布Q来逼近P(Z∣X),当Q和P(Z∣X)之前的差距很小时就可以用Q来近似代替P(Z∣X)。
数学推导求解Q
根据概率公式,我们有
logP(X)=logP(X,Z)−logP(Z∣X)(1)
(1)式右边的两项log里面的部分同时除以Q(Z),(1)式的值依然相等。即
logP(X)=logQ(Z)P(X,Z)−logQ(Z)P(Z∣X)(2)
(2)式两边同时乘以Q(Z),并对Z求积分得。
左边=右边===∫ZlogP(X)Q(Z)dZ=logP(X)∫Z(Q(Z)logQ(Z)P(X,Z)−Q(Z)logQ(Z)P(Z∣X))dZ∫ZQ(Z)logQ(Z)P(X,Z)dZ−∫ZQ(Z)logQ(Z)P(Z∣X))dZL(Q)+KL(Q∣∣P)
上式中,第一项记为L(Q),第二项连同负号是KL散度的形式,表示的是后验分布P(Z∣X)与分布Q(Z)之间的距离。
所以有
logP(X)=L(Q)+KL(Q∣∣P)(3)
由于KL(Q∣∣P)是大于等于0的,所以L(Q)<=logP(X),此时也称L(Q)是ELBO(证据下界)。值得注意的是,此时的L(Q)是分布Q的函数(也称为泛函),当数据给定的时候,logP(X)是不变的,所以如果我们让L(Q)最大化,也就是会让KL(Q∣∣P)最小化,那么我们就会找到一个Q,从而用这个Q近似后验分布P(Z∣X)。
根据平均场理论,我们可以将Z分割为独立的M份,也就是
Q(Z)=i∏MQi(Zi)(4)
由前面的推导知,
L(Q)=∫ZQ(Z)logQ(Z)P(X,Z)dZ(5)
将(4)代入(5)式,得
L(Q)=∫Zi∏MQi(Zi)logP(X,Z)dZ−∫Zi∏MQi(Zi)logi∏MQi(Zi)dZ=∫Zi∏MQi(Zi)logP(X,Z)dZ−∫Zi∏MQi(Zi)i∑MlogQi(Zi)dZ(6)
考虑其中的某一项Qj(Zj),分别计算(6)式的两部分。



得到Qi(Zi)之后,根据公式(4)便得到了我们要求的分布Q。