Strategies For Pre-Training Graph Neural Networks
Paper : STRATEGIES FOR PRE-TRAINING GRAPH NEURAL NETWORKS
Code : official
摘要
作者解决的问题是如何预训练一个GNN网络,保证预训练的结果在具体数据集中finetune不会negative transfer 的现象。作者在文中并没有细致的解释为什么GNN上进行transfer learning 会更难,这个可能需要翻一下该篇文章的引用paper。作者提出了两步的预训练策略,分别在节点层面和图层面进行预训练,保证GNN预训练过程中可以同时学到局部和全局的信息,最后通过实验证明作者提出的transfer learning 策略可以有效地改善模型的表现。
两个级别
作者认为,节点层面的预训练可以很好的区分局部不同子图形状的节点,但是对于节点编码来说不可组合,全图层面的预训练可以很好区分图的编码,但是节点层面的编码不能表示局部的语义信息,高质量的预训练需要结合两部分。因此作者采用了两步预训练的方式,节点级别的预训练和图级别的预训练。
节点级别预训练
作者提出了两种自监督式的预训练方法,Context Prediction 和 Attribute Masking
Context Prediction
Context Prediction 使用子图来预测周围的图结构,优化目标是预训练一个GNN可以将具有相近结构的节点映射到相似的特征编码上。
K-hop Neighbourhood:距离节点 v 至多 k 的所有节点构成的子图
Context Graph : 到节点 v 的距离在 r1 ~ r2 之间所有的节点构成的多个图
Context Anchor Node : 取 r1 < K ,所有到节点 v 的距离在 r1 ~ K 之间的节点
作者使用辅助GNN来将context graphs 编码成一个固定长度的向量,步骤如下
- 使用context GNN 对每个context graph 进行编码,得到节点上的特征表示
- 对所有的context archor nodes 取平均,获得定长向量表示,设节点 的context embedding 表示为
使用负采样的方式来联合学习主干网络和context GNN,Context Prediction 的学习目标是一个二分类结果,即当节点 和 属于同一个节点时
而负样本通过对任意图中随机节点的采样获得,正负样本比率为 1:1。
Attribute Masking
Attribute Masking 旨在通过学习分布在图结构上的节点/边属性的规律性来捕获领域知识。
Attribute Masking 预训练的工作原理如下:屏蔽节点/边属性,然后让GNN根据相邻结构预测这些属性。具体来说,通过用特殊的屏蔽指示符替换输入节点/边属性(例如分子图中的原子类型)来随机屏蔽它们。然后,我们应用GNN获得相应的节点/边编码表示。最后,在嵌入的顶部应用线性模型以预测被掩盖的节点/边属性。我们对非完全连接图进行操作,旨在捕获分布在不同图结构上的节点/边属性的规则性。
图级别预训练
图级别的预训练可以有两个方向,对整张图的编码/属性进行预训练,或是对图结构进行预训练。
SUPERVISED GRAPH-LEVEL PROPERTY PREDICTION
作者表示,仅简单地执行广泛的多任务图级别预训练可能无法提供可迁移的图级别表示形式。这是因为某些受监督的预训练任务可能与下游任务无关。甚至会损害下游任务的表现。 一种解决方案是选择“真正相关的”有监督的预训练任务,并仅对那些任务进行预训练GNN。 但是,由于选择相关任务需要大量的领域专业知识,并且需要针对不同的下游任务分别进行预培训,因此这种解决方案的成本非常高。
作者认为,多任务监督式预训练仅提供图级别监督; 因此,从中创建图级嵌入的本地节点嵌入可能没有意义。由于许多不同的预训练任务可以更容易地在节点嵌入空间中相互干扰,因此这种无用的节点嵌入会加剧负迁移问题。 因此,我们的预训练策略是在执行图级预训练之前,首先通过节点级预训练方法在单个节点级别上对GNN进行正则化。 这种组合策略可以产生更多可迁移的图表示,并且可以在无需专家选择有监督的预训练任务的情况下稳健地提高下游性能。
STRUCTURAL SIMILARITY PREDICTION
作者认为找到ground truth 图距离值是一个难题,在大型数据集中要考虑的图对数量是平方级问题,因此不采用这种方法。
实验结果
根据实验结果,作者总结出以下几点
- 表示能力最强的GIN结构在迁移学习中获得了最大的提升
- 只在GNN上进行图级别的迁移学习提升有限,甚至会发生负向迁移
- 只做节点级别的迁移学习提升同样受限
- 节点级别和图级别的迁移学习达到了最好的结果
- MUTAG, PTC molecule datasets 等小数据集不适合进行性能比较
- Pretrained数据集收敛速度更快
总结
作者提出了该方法未来的几个改进点
- 改进GNN的结构以及预训练和微调的方法,以进一步提高泛化能力。
- 研究预训练模型学习到了什么有用的信息。
- 将这一方法应用于其他领域,例如物理、材料科学、生物结构等。
- 图级别预测的预训练任务中,是否能增加图结构相似性的预测任务。
总的来说,作者提出了一种GNN上预训练的方法,但是我个人认为没有仔细分析为什么GNN上进行迁移学习相比CNN来说效果会差很多,这是一个可以做进一步研究的地方。作者的想法无论是node level还是graph level有点过于直觉了。