Paper : Contrastive Learning Of Structured World Models
Code : official
摘要
作者在强化学习任务中,提出了使用GNN来建模多物体之间的关系的方法,使用transE的思路来进行状态的预测,并通过Contrastive Loss 来进行训练,提出 C-SWM 模型。作者在雅达利游戏,多物体物理模拟等任务上达到了SOTA。
方法
C-SWM模型可以对物体与物体之间的关系建边,目的是对每个物体学习到它的特征表示,通过Contrastive Loss 对模型进行自监督学习。假定强化学习过程中环境状态和操作序列为 B={(st,at,st+1)}t=1T ,强化学习的目标是学习环境状态 st 对应的隐变量 zt∈Z ,来预测执行了操作 at 之后的状态 st+1 ,因此需要两个函数 E:S→Z,T:Z×A→Z
仿照TransE的方法对损失函数进行设计
L=d(zt+T(zt,at),zt+1)+max(0,γ−d(zt,zt+1))
其中 zt=E(st),st 是从经验序列中随机抽取的状态,表示负例。
C-SWM模型的整体架构如下所示

对象提取器和编码器:将编码器分为两个单独的模块:1)基于CNN的对象提取器Eext,2)基于MLP的编码器Eenc。对象提取器是一个CNN,最后一层具有 K 个特征图。每个特征图 mtk=[Eext(st)]k 可以解释为对应于一个特定对象的mask,其中 [...]k 表示选择第k个特征图。为简单起见,我们仅为每个对象分配一个特征图。为了允许编码更复杂的对象特征,对象提取器可以适于在每个对象槽中产生多个特征图。在对象提取器模块之后,我们对每个特征图 mtk 进行平整并将其送到对象编码器Eenc 中。对象编码器在各个对象之间共享权重,并返回抽象状态表示形式:ztk=Eenc(mtk) ,其中ztk∈Zk。我们在下面设置Zk=RD,其中D是超参数。
关系转换模型:将转换模型实现为图神经网络,它使我们能够对对象状态之间的成对交互进行建模,同时不影响对象表示的顺序。在编码器阶段之后,我们对场景中的每个对象都有一个抽象的状态描述ztk∈Zk 和一个动作 atk∈Ak。我们将动作表示为one-hot 向量,然后,转换函数将对象表示的元组 zt=(zt1,...,ztK) 和动作 at=(at1,...atK) 在特定的时间步转化为
Δzt=T(zt,at)=GNN({(ztk,atk)}k=1K)
其中 K 表示场景中对象的个数,下一个时间步中状态更新为
zt+1=(zt1+Δzt1,...,ztK+ΔztK)
其中边特征向量的编码器和增量 Δzt 的解码器定义为
et(i,j)=fedge([zti,ztj])Δztj=fnode([ztj,atj,i=j∑et(i,j)])
对于多物体来说,损失函数定义为
L=H+max(0,γ−H)H=K1k=1∑Kd(ztk+Tk(zt,at),zt+1k)H=K1k=1∑Kd(ztk,zt+1k)
算法限制
歧义消歧:对象提取器模块是一种简单的前馈CNN架构。这种类型的架构无法消除一个场景中存在的同一对象的多个实例的歧义,而是依赖于不同的视觉特征或标签进行对象提取。为了更好地处理可能包含同一对象的多个副本的场景,需要某种形式的迭代消歧过程来打破对称性,并将单个对象动态绑定到插槽或对象文件。
随机性和马尔可夫假设:C-SWM公式没有考虑环境转变或观测中的随机性,因此仅限于完全确定性的世界。C-SWM的概率扩展是将来工作之一。为简单起见,我们做出马尔可夫假设:状态和动作包含预测下一个状态所需的所有信息。这使我们可以孤立地查看单个状态-动作-状态三元组。为了超越这一限制,将需要某种形式的存储机制,例如作为模型体系结构一部分的RNN。
结论
该文章在使用方法本身,创新性有限,不过开辟新的解决思路,实验结果也还不错。在强化学习的应用中表明图神经网络是学习多物体动态变化的比较有力的方法,可以泛化到其他多物体的任务上去。