WBCE损失重写

WBCE

WBCE 即 weighted binary cross entropy,是 [1] 的公式 1,改版的 binary cross entropy。
Lwbce(y,z,w)=i=1c[wiyilogzi+(1yi)log(1zi)]L^{wbce}(y,z,w)=-\sum_{i=1}^c[w_i\cdot y_i \log z_i+(1-y_i)\log (1-z_i)]
其中,y 是真实 label 向量,z 是预测 label 向量,w 是权重向量,wi=#0{i}#1{i}w_i=\frac{\#0\{i\}}{\#1\{i\}}#0{i}\#0\{i\} 是 training set 中属于第 i 类的样本个数,#1{i}\#1\{i\} 类似地表示属于的。
wiw_i 是乘在中括号面的。

Reformulation

当 z 是 sigmoid 的输出(如 [1] 的公式 4),记 z=σ(x)z=\sigma(x)σ()\sigma(\cdot) 表示 sigmoid 函数),参照 [2],修改公式防溢出。
先按 y、z、w 都是标量的情况考虑:
Lwbce(y,z,w)=[wylogz+(1y)log(1z)]=wylogσ(x)(1y)log[1σ(x)]=wylog(1+ex)+(1y)[x+log(1+ex)]=(wy+1y)log(1+ex)+(1y)x\begin{aligned}L^{wbce}(y,z,w)&=-[w\cdot y \log z+(1-y)\log (1-z)] \\ &=-wy\log\sigma(x)-(1-y)\log[1-\sigma(x)] \\ &=wy\log(1+e^{-x}) + (1-y)[x+\log(1+e^{-x})] \\ &=(wy+1-y)\log(1+e^{-x})+(1-y)x \end{aligned}
因为 exe^{-x}x<0x<0 时很爆炸:
WBCE损失重写
x<0x<0 的情况特殊处理,变形:
Lwbce(y,z,w)=(wy+1y)log(1+ex)+(1y)x(1)=(wy+1y)log(1+ex)+(wy+1y)xwyx=(wy+1y)[log(1+ex)+logex]wyx=(wy+1y)log(1+ex)wyx(2)\begin{aligned}L^{wbce}(y,z,w)&=(wy+1-y)\log(1+e^{-x})+(1-y)x &(1) \\ &=(wy+1-y)\log(1+e^{-x})+(wy+1-y)x-wyx \\ &=(wy+1-y)[\log (1+e^{-x})+\log e^x]-wyx \\ &=(wy+1-y)\log (1+e^x)-wyx &(2) \end{aligned}
(1) 式适用 x0x\ge0 的情况,(2) 适用在 x<0x<0,于是:
Lwbce(y,z,w)={(wy+1y)log(1+ex)+(1y)x,x0(wy+1y)log(1+ex)wyx,x<0=(wy+1y)log(1+ex)+max{(1y)x,0}min{wyx,0}\begin{aligned} L^{wbce}(y,z,w)&=\left\{\begin{array}{cc} (wy+1-y)\log(1+e^{-x})+(1-y)x, & x\ge0 \\ (wy+1-y)\log (1+e^x)-wyx, & x<0 \end{array}\right. \\ &=(wy+1-y)\log (1+e^{-|x|})+\max\{(1-y)x,0\}-\min\{wyx,0\} \end{aligned}
后面 max 和 min 两项是因为 y{0,1}y\in\{0,1\}w0w\ge0,所以 (1y)x(1-y)xwyxwyx 的符号看 xx

References

  1. Semi-Supervised Cross-Modal Retrieval with Label Prediction
  2. tf.nn.sigmoid_cross_entropy_with_logits
  3. Giving reasons on each line of a sequence of equations