A Simple Framework for Contrastive Learning of Visual Representations

Abstract

提出了一种用于视觉表示的对比学习简单框架,并且结构简单,不需要专门的架构或特殊的存储库。
作者发现:
(1)数据增强的组合在定义有效的预测任务中起着至关重要的作用。
(2)在表示和对比损失之间引入可学习的非线性变换,大大提高了学习表示的质量。
(3)与监督式学习相比,对比式学习得益于更大的批量和更多的训练步骤,更深更宽的网络也更加有用。
(4)具有对比交叉熵损失的表示学习得益于归一化嵌入(normalized embeddings)和适当调整的温度(temperature)参数。
使用SimCLR训练的线性分类器在ImageNet上实现了超越之前所有半监督与自监督的方法,跟监督学习算法相媲美。

Method

The Contrastive Learning Framework

SimCLR通过潜在空间中的对比损失来最大化同一数据示例的不同增强视图之间的一致性,从而学习表示形式。
A Simple Framework for Contrastive Learning of Visual Representations
f(.)f(.)作者使用的是一个ResNet结构hi=f(x˜i)=ResNet(x˜i)h_i = f(x˜_i) = ResNet(x˜_i),hiRdh_i ∈ R^d
g(.)g(.)是一个MLP(多层感知机),使用contrastive loss,zi=g(hi)=W(2)σ(W(1)hi)z_i = g(h_i) = W^{(2)}σ(W^{(1)}h_i).σσ 是一个 ReLU nonlinearity.作者发现在ziz_i上定义contrastive loss比在hih_i上定义更好,优化使用的是线性学习率缩放的LARS(i.e. LearningRate = 0.3 × BatchSize/256),weight decay 为10610^{-6}。作者在4096的batchsize下训练了100个epoch,并在前10个epoch下使用了 linear warmup,

对比预测任务中的 contrastive loss function定义:给定一个包含positive pair x˜ix˜_ix˜jx˜_j的集合{x˜kx˜_k},对比预测任务就是给定x˜ix˜_i,然后从{x˜kx˜_k}k!=i_{k!=i}中鉴别出x˜jx˜_j。作者在一个minibatch中随机采样N个样本点,每个样本点做两次增强当作该样本点positive pair,其余2(N-1)个增强数据当作negative examples,loss函数的定义为:
A Simple Framework for Contrastive Learning of Visual Representations
作者称之为 NT-Xent(归一化温度尺度交叉熵损失)。其中sim(zi,zj)sim(z_i,z_j)是这两个向量之间的cos相似度。
A Simple Framework for Contrastive Learning of Visual Representations

Training with Large Batch Size

1.用大的batchsize。
2.加BN,LN啥的

Data Augmentation for Contrastive Representation Learning