变分推断二(基于随机梯度求解分布Q)

高方差的问题

根据上一节变分推断一(根据平均场理论求解Q)我们得到了需要求解的分布QQ的函数。
L(Q)=ZQ(Z)logP(X,Z)Q(Z)dZQ=EQ(Z)[logP(X,Z)logQ(Z)](1) \begin{aligned} \tag{1} L(Q) = & \int_{Z} Q(Z)log{P(X,Z) \over Q(Z)} dZ \\Q = & E_{Q(Z)}[logP(X, Z) - logQ(Z)] \end{aligned}
我们最终的目的是求解QQ,在实际中QQ分布是有参数的,参数记为φ\varphi,只要求解了参数φ\varphi,也就求得了分布QQ。因此我们可以将(1)式进一步写成关于未知参数φ\varphi的函数。即
L(φ)=Eqφ(z)[logp(xi,z)logqφ(z)](2)\tag{2}L(\varphi) = E_{q_{\varphi}(z)}[logp(x^{i},z) - logq_{\varphi}(z)]
其中xix^{i}表示第i个样本,并且将(1)式中的大写字母全部转化为小写。这对推导并没有影响。

既然题目是用梯度来求未知参数φ\varphi,那么就要对(2)式求关于φ\varphi的导数。
φL(φ)=φ(Eqφ(z)[logp(xi,z)logqφ(z)])=φzqφ(z)[logp(xi,z)logqφ(z)]dz=zφqφ(z)[logp(xi,z)logqφ(z)]dz+zqφ(z)φ[logp(xi,z)logqφ(z)]dz=A+B(3) \begin{aligned} \tag{3} \nabla_{\varphi}L(\varphi) = & \nabla_{\varphi}(E_{q_{\varphi}(z)}[logp(x^{i}, z)-logq_{\varphi}(z)]) \\ = & \nabla_{\varphi}\int_{z}q_{\varphi}(z)[logp(x^{i}, z) -logq_{\varphi}(z)]dz \\ = & \int_{z}\nabla_{\varphi}q_{\varphi}(z)[logp(x^{i}, z) - logq_{\varphi}(z)]dz \\ & +\int_{z}q_{\varphi}(z)\nabla_{\varphi}[logp(x^{i}, z)-logq_{\varphi}(z)]dz \\ = & A + B \end{aligned}
将(3)式第三行的两项分别记为ABA和B,接下来分别求解。
变分推断二(基于随机梯度求解分布Q)
所以最后L(φ)φL(\varphi)对\varphi的导数就可以用(4)式所示的期望来代替。这样我们就可以用蒙特卡洛模拟的方法,从qφ(z)q_{\varphi}(z)中采样若干个点,然后来近似(4)式的期望,也就是近似φL(φ)\nabla_{\varphi}L(\varphi)。这样就可以使用梯度下降的方法来更新φ\varphi,最后求得φ\varphi

上面的方法看似可以,但是仔细分析会存在一些问题。(4)式是函数φlogqφ(z)[logp(xi,z)logqφ(z)]\nabla_{\varphi}logq_{\varphi}(z)[logp(x^{i}, z)-logq_{\varphi}(z)]在分布qφ(z)q_{\varphi}(z)下的期望,但是logqφ(z)logq_{\varphi}(z)的梯度变化会非常大(log函数的图像是由陡变缓的)。假如采样了两个点z1,z2qφ(z1)0qφ(z2)1z_{1}, z_{2},但是q_{\varphi}(z_{1})接近0,而q_{\varphi}(z_{2})接近1,求导之后这两个点的梯度差是非常大的,所以会存在高方差的问题,高方差问题会导致在梯度更新时不稳定。所以就需要一种方法来降方差,使得梯度能稳定的更新。

重参数化降方差

关于重参数化技巧可以看苏剑林的科学空间漫谈重参数,讲解的很详细。
从(2)式可以看到,问题的根源是zz是从分布qφ(z)q_{\varphi}(z)中采样得到的,所以将(2)式转化为积分形式后(如(3)式所示),里面会出现qφ(z)q_{\varphi}(z),再对φ\varphi求导就会变成(4)式,里面就会出现一项φlogqφ(z)\nabla_{\varphi}logq_{\varphi}(z),这就会导致高方差的问题。

要是zz不从qφ(z)q_{\varphi}(z)中直接采样,而是从一个已知的分布p(ε)p(\varepsilon)中采样得到ε\varepsilon,再通过一个变换z=gφ(ε)z=g_{\varphi}(\varepsilon)得到zz,通过这样的过程来采样zz,将zz的随机性转化为ε\varepsilon的随机性,这样就消除了高方差的问题。下面就通过公式来体验一下。

已知:
εp(ε)z=gφ(ε)\varepsilon \thicksim p(\varepsilon),z = g_{\varphi}(\varepsilon)
变分推断二(基于随机梯度求解分布Q)
通过推导我们得到了(6)式。在计算时,先从p(ε)p(\varepsilon)中采样出ε1,,,εk\varepsilon^{1}, ,,\varepsilon^{k},对于某个εi\varepsilon^{i},求出zf(z)\nabla_{z}f(z)zf(z)\nabla_{z}f(z)中必定含有zz,再将z=gφ(εi)z=g_{\varphi}(\varepsilon^{i})带入计算,最后得到zf(zi)φgφ(εi)\nabla_{z}f(z^{i})\nabla_{\varphi}g_{\varphi}(\varepsilon^{i}),则
φL(φ)=1ki=1kzf(zi)φgφ(εi)φ(t+1)=φ(t)+λ(t)φL(φ) \begin{aligned} \nabla_{\varphi}L(\varphi)= & {1 \over k} \sum_{i = 1}^{k} \nabla_{z}f(z^{i})\nabla_{\varphi}g_{\varphi}(\varepsilon^{i}) \\ \varphi^{(t+1)}=& \varphi^{(t)} + \lambda^{(t)}\nabla_{\varphi}L(\varphi) \end{aligned}
通过上面的梯度更新,最后便可算出φ\varphi,也就求得了分布qφ(z)q_{\varphi}(z)。就可以用qφ(z)p(zx)q_{\varphi}(z)来近似代替后验分布p(z|x)

最后推荐苏剑林的科学空间中的有关博客和b站白板推导系列视频。