变分推断一(基于平均场理论求解Q)

本文主要参考B站的白板推导系列,推荐大家观看。

什么是变分推断

XX : 观测变量
ZZ:latent variable + parameter
在变分推断中,样本点XX被称为观测变量(observed data),未知参数和潜变量被称为不可观测变量,都用ZZ来表示。

我们的模型一般都是根据观测数据来求ZZ的后验分布,也就是求P(ZX)P(Z|X),但是有的时候,P(ZX)P(Z|X)是不好求的,所以可以尝试用一个容易求解的分布QQ来逼近P(ZX)P(Z|X),当QQP(ZX)P(Z|X)之前的差距很小时就可以用QQ来近似代替P(ZX)P(Z|X)

数学推导求解QQ

根据概率公式,我们有
logP(X)=logP(X,Z)logP(ZX)(1)\tag{1} logP(X) = logP(X, Z) - logP(Z|X)
(1)式右边的两项log里面的部分同时除以Q(Z)Q(Z),(1)式的值依然相等。即
logP(X)=logP(X,Z)Q(Z)logP(ZX)Q(Z)(2)\tag{2} logP(X) = log{P(X, Z) \over Q(Z)} - log{P(Z|X) \over Q(Z)}
(2)式两边同时乘以Q(Z)Q(Z),并对ZZ求积分得。

=ZlogP(X)Q(Z)dZ=logP(X)=Z(Q(Z)logP(X,Z)Q(Z)Q(Z)logP(ZX)Q(Z))dZ=ZQ(Z)logP(X,Z)Q(Z)dZZQ(Z)logP(ZX)Q(Z))dZ=L(Q)+KL(QP) \begin{aligned} 左边 = & \int_{Z} logP(X)Q(Z)dZ \\ & = logP(X) \\ 右边= & \int_{Z} ( Q(Z)log{P(X, Z) \over Q(Z)}- Q(Z)log{P(Z|X) \over Q(Z)})dZ \\ = & \int_{Z} Q(Z)log{P(X, Z) \over Q(Z)}dZ - \int_{Z} Q(Z)log{P(Z|X) \over Q(Z)})dZ \\ = & L(Q) + KL(Q||P) \end{aligned}
上式中,第一项记为L(Q)L(Q),第二项连同负号是KL散度的形式,表示的是后验分布P(ZX)P(Z|X)与分布Q(Z)Q(Z)之间的距离。
所以有
logP(X)=L(Q)+KL(QP)(3) \begin{aligned} \tag{3} logP(X) = L(Q) + KL(Q||P) \end{aligned}
由于KL(QP)KL(Q||P)是大于等于0的,所以L(Q)<=logP(X)L(Q) <= logP(X),此时也称L(Q)L(Q)是ELBO(证据下界)。值得注意的是,此时的L(Q)L(Q)是分布QQ的函数(也称为泛函),当数据给定的时候,logP(X)logP(X)是不变的,所以如果我们让L(Q)L(Q)最大化,也就是会让KL(QP)KL(Q||P)最小化,那么我们就会找到一个QQ,从而用这个QQ近似后验分布P(ZX)P(Z|X)

根据平均场理论,我们可以将ZZ分割为独立的MM份,也就是
Q(Z)=iMQi(Zi)(4)\tag{4} Q(Z) = \prod_{i}^{M} Q_{i}(Z_{i})
由前面的推导知,
L(Q)=ZQ(Z)logP(X,Z)Q(Z)dZ(5)\tag{5} L(Q) = \int_{Z} Q(Z)log{P(X, Z) \over Q(Z)}dZ
将(4)代入(5)式,得
L(Q)=ZiMQi(Zi)logP(X,Z)dZZiMQi(Zi)logiMQi(Zi)dZ=ZiMQi(Zi)logP(X,Z)dZZiMQi(Zi)iMlogQi(Zi)dZ(6) \begin{aligned} \tag{6} L(Q) = \int_{Z} \prod_{i}^{M}Q_{i}(Z_{i})logP(X,Z) dZ-\int_{Z} \prod_{i}^{M}Q_{i}(Z_{i})log\prod_{i}^{M}Q_{i}(Z_{i})dZ \\ = \int_{Z} \prod_{i}^{M}Q_{i}(Z_{i})logP(X,Z) dZ - \int_{Z} \prod_{i}^{M}Q_{i}(Z_{i})\sum_{i}^{M} logQ_{i}(Z_{i})dZ \end{aligned}
考虑其中的某一项Qj(Zj)Q_{j}(Z_{j}),分别计算(6)式的两部分。
变分推断一(基于平均场理论求解Q)
变分推断一(基于平均场理论求解Q)
变分推断一(基于平均场理论求解Q)
得到Qi(Zi)Q_{i}(Z_{i})之后,根据公式(4)便得到了我们要求的分布QQ