【论文笔记】2017-NIPS-Causal Effect Inference with Deep Latent-Variable Models
Causal Effect Inference with Deep Latent-Variable Models
笔者最近在做causal inference这个方向,因此会把日常读到的还(neng)不(kan)错(dong)的paper简单整理一下做个笔记,欢迎感兴趣的童鞋交流讨论~
背景
Causal inference涉及到的数据集通常由三个变量组成 { X , T , Y } \left\{X,T,Y\right\} {X,T,Y}。其中, X X X代表特征(covariate),例如病人的身体、经济状况, T T T代表某个操作(treatment),通常是0-1的,例如是否服用某种药物, Y Y Y代表输出(outcome),例如病人一段时间后的血压血糖水平。简单来说,causal inference的任务是想在给定 X X X的情况下估计 T T T对 Y Y Y的影响。
本文作者考虑的一个问题是如何消除hidden confounder对causal inference的影响。简单来说,confounder
Z
Z
Z就是会对
T
T
T和
Y
Y
Y都产生影响的变量,例如一个人的经济实力社会地位,这些会对他是否能够服用某种药物产生影响,但
Z
Z
Z又是一个很难准确观测的变量。这里,作者假设可观测到的
X
X
X是
Z
Z
Z的代理变量,例如我们虽然很难准确度量一个人的社会地位
Z
Z
Z,但可以通过调查他的职业收入
X
X
X侧面反映
Z
Z
Z。这里,作者构建了一个如下的因果图:
这个因果图理解起来也比较直观,深色的是可以观测到的,白色的是无法观测到的,
X
X
X是
Z
Z
Z的一个noisy observation,因此
Z
→
X
Z\rightarrow X
Z→X,其他几个箭头都是causal inference里的常用假设。
方法
其实看到这个因果图,熟悉VAE的童鞋可能已经猜到了作者的思路,就是把 Z Z Z当做隐空间表示,然后套用VAE的架构。
Encoder:文章里叫inference network,结构如下图:
这个结构作者参考的是TARnet网络,这是causal inference里一个非常经典的深度模型,会在之后的博客里介绍。
q
(
t
∣
x
)
q(t|x)
q(t∣x)是在计算propensity score(不过这个东西在原始TARnet并没有用到,估计是作者为了实验效果后加上去的一项),在学完共同特征表示之后,根据
t
=
0
/
1
t=0/1
t=0/1接出两个分支。
Decoder:文章中叫model network,结构如下:
这个结构可以根据之前的因果图分解得到:
p
(
x
,
t
,
y
,
z
)
=
p
(
z
)
p
(
t
,
x
,
y
∣
z
)
=
p
(
z
)
p
(
t
,
y
∣
z
)
p
(
x
∣
t
,
y
,
z
)
=
p
(
z
)
p
(
t
,
y
∣
z
)
p
(
x
∣
z
)
=
p
(
z
)
p
(
t
∣
z
)
p
(
y
∣
t
,
z
)
p
(
x
∣
z
)
p(x,t,y,z)=p(z)p(t,x,y|z)=p(z)p(t,y|z)p(x|t,y,z)=p(z)p(t,y|z)p(x|z)=p(z)p(t|z)p(y|t,z)p(x|z)
p(x,t,y,z)=p(z)p(t,x,y∣z)=p(z)p(t,y∣z)p(x∣t,y,z)=p(z)p(t,y∣z)p(x∣z)=p(z)p(t∣z)p(y∣t,z)p(x∣z)
目标函数的推导与VAE基本一致:
当然,就像笔者之前提到的,为了实验效果,作者又在原始VAE loss上加了新的两项:
结论
这应该是第一篇利用深度生成模型求解causal inference的文章,文章的motivation(解决hidden confounder)和构建因果图的方式( X X X是 Z Z Z的noisy observation)很让人信服,不过实验效果好像一般(hhh可能也是因为如此大家都喜欢把它当做baseline),套用VAE的框架也不算难,读起来也比较轻松。