CONTRASTIVE REPRESENTATION DISTILLATION

Tian Y., Krishnan D., Isola P. CONTRASTIVE REPRESENTATION DISTILLATION. arXiv preprint arXiv 1910.10699, 2019.

感觉其和的相似度有50%, 不过这篇写得早一点, 所以后者是借鉴了这篇文章? 这篇文章总的来说就是将distillation 和 contrastive learning 结合起来.

主要内容

CONTRASTIVE REPRESENTATION DISTILLATION

思想便是, 希望 f S ( x i ) f^S(x_i) fS(xi)靠近 f T ( x i ) f^T(x_i) fT(xi), 而 f S ( x j ) f^S(x_j) fS(xj)远离 f T ( x i ) f^T(x_i) fT(xi). 定义
S : = f S ( x ) , T : = f T ( x ) . S:=f^S(x), \quad T:= f^T(x). S:=fS(x),T:=fT(x).
假设源于同一样本的联合分布 P ( S , T ∣ C = 1 ) P(S,T|C=1) P(S,TC=1) P 1 ( S , T ) P_1(S,T) P1(S,T), 而源于不同样本的联合分布 P ( S , T ∣ C = 0 ) P(S,T|C=0) P(S,TC=0) P 0 ( S ) P 0 ( T ) P_0(S)P_0(T) P0(S)P0(T). 则我们很自然地希望最大化互信息:
I ( S , T ) = E P 1 ( S , T ) log ⁡ P 1 ( S , T ) P 0 ( S ) P 0 ( T ) . I(S,T)= \mathbb{E}_{P_1(S,T)} \log \frac{P_1(S,T)}{P_0(S)P_0(T)}. I(S,T)=EP1(S,T)logP0(S)P0(T)P1(S,T).

接下来就是负采样和对比学习的东西了, 假设数据集是如此构造的: 一个特征 T T T, 以及N+1个特征 { S , S 1 , … , S N } \{S,S_1,\ldots, S_N\} {S,S1,,SN}, 其中 S , T S,T S,T构成正样本对(即来源于同一个样本, 其余 S i , T S_i,T Si,T构成负样本对. 则我们有先验
P ( C = 1 ) = 1 N + 1 , P ( C = 0 ) = N N + 1 . P(C=1)=\frac{1}{N+1}, P(C=0)=\frac{N}{N+1}. P(C=1)=N+11,P(C=0)=N+1N.
于是便有
P ( C = 1 ∣ T , S ) = P 1 ( T , S ) P 1 ( T , S ) + N P 0 ( T ) P 0 ( S ) , P(C=1|T,S)=\frac{P_1(T,S)}{P_1(T,S)+NP_0(T)P_0(S)}, P(C=1T,S)=P1(T,S)+NP0(T)P0(S)P1(T,S),

log ⁡ P ( C = 1 ∣ T , S ) = − log ⁡ ( 1 + N P 0 ( T ) P 0 ( S ) P 1 ( T , S ) ) ≤ − log ⁡ N + log ⁡ P 1 ( T , S ) P 0 ( T ) P 0 ( S ) . \begin{array}{ll} \log P(C=1|T,S) &= -\log (1+N\frac{P_0(T)P_0(S)}{P_1(T,S)}) \\ & \le -\log N + \log \frac{P_1(T,S)}{P_0(T)P_0(S)}. \end{array} logP(C=1T,S)=log(1+NP1(T,S)P0(T)P0(S))logN+logP0(T)P0(S)P1(T,S).
两边关于 P 1 ( T , S ) P_1(T,S) P1(T,S)求期望可知
I ( T , S ) ≥ log ⁡ N + E P 1 ( T , S ) log ⁡ P ( C = 1 ∣ T , S ) . I(T,S) \ge \log N + \mathbb{E}_{P_1(T, S)} \log P(C=1|T,S). I(T,S)logN+EP1(T,S)logP(C=1T,S).

但是 P ( C = 1 ∣ T , S ) P(C=1|T,S) P(C=1T,S)未知, 故作者采用 h ( T , S ) h(T,S) h(T,S)去拟合, 通过极大似然估计
L c r i t i c ( h ) = E P 1 ( T , S ) log ⁡ h ( T , S ) + N E P 0 ( T , S ) log ⁡ ( 1 − h ( T , S ) ) . \mathcal{L}_{critic}(h)= \mathbb{E}_{P_1(T,S)} \log h(T,S) + N \mathbb{E}_{P_0(T,S)}\log (1-h(T,S)). Lcritic(h)=EP1(T,S)logh(T,S)+NEP0(T,S)log(1h(T,S)).
只要 h h h的拟合能力够强, 最后便能很好的逼近 P ( C = 1 ∣ T , S ) P(C=1|T,S) P(C=1T,S). 设其最优解为 h ∗ h^* h. 但是需要注意的一点是, h ∗ h^* h T , S T, S T,S有关系, 则其隐式地和 f S f^S fS有关系, 而 f S f^S fS又需要
max ⁡ f S E P 1 log ⁡ h ∗ ( T , S ) , \max_{f^S} \mathbb{E}_{P_1} \log h^*(T,S), fSmaxEP1logh(T,S),
所以这就成了一个交替迭代的过程. 作者就另辟蹊径, 既然
I ( T , S ) ≥ log ⁡ N + E P 1 ( T , S ) log ⁡ h ∗ ( T , S ) + N E P 0 ( T , S ) log ⁡ ( 1 − h ∗ ( T , S ) ) ≥ log ⁡ N + E P 1 ( T , S ) log ⁡ h ( T , S ) + N E P 0 ( T , S ) log ⁡ ( 1 − h ( T , S ) ) . \begin{array}{ll} I(T,S) &\ge \log N + \mathbb{E}_{P_1(T,S)} \log h^*(T,S) + N \mathbb{E}_{P_0(T,S)}\log (1-h^*(T,S)) \\ & \ge \log N + \mathbb{E}_{P_1(T,S)} \log h(T,S) + N \mathbb{E}_{P_0(T,S)}\log (1-h(T,S)). \end{array} I(T,S)logN+EP1(T,S)logh(T,S)+NEP0(T,S)log(1h(T,S))logN+EP1(T,S)logh(T,S)+NEP0(T,S)log(1h(T,S)).

便不妨共同优化 f S , h f^S, h fS,h.

注: 第二个不等式成立, 因为 h ( T , S ) ∈ [ 0 , 1 ] h(T,S) \in [0, 1] h(T,S)[0,1], 故第二项非正.

文中取的 h h h
h ( T , S ) = e g T ( T ) ′ g S ( S ) / τ e g T ( T ) ′ g S ( S ) / τ + N M , h(T,S)=\frac{e^{g^T(T)'g^S(S)/\tau}}{e^{g^T(T)'g^S(S)/\tau} + \frac{N}{M}}, h(T,S)=egT(T)gS(S)/τ+MNegT(T)gS(S)/τ,
其中, g g g为一线性变换, τ \tau τ为temperature, M M M为整个数据集的大小.

超参数的选择

CIFAR100:
N: 16384
τ \tau τ: 0.1

代码

原文代码