SAGPool - Self-Attention Graph Pooling 图分类 图池化方法 ICML 2019

论文:Self-Attention Graph Pooling

作者:Junhyun Lee, Inyeop Lee, Jaewoo Kang
韩国首尔高丽大学计算机科学与工程系

来源:ICML 2019

论文链接:
Arxiv: https://arxiv.org/abs/1904.08082

github链接:https://github.com/inyeoplee77/SAGPool

近年来,人们提出了将深度学习应用于图数据等结构化数据的方法。研究工作集中在将CNN泛化到图数据,重新定义图数据的卷积操作和downsampling(池化)操作。将卷积运算推广到图上的方法已被证明了可以提高性能,并得到了广泛的应用。然而,将downsampling应用于图数据的方法仍然难以实现,还有改进的空间。此文提出了一种基于self-attention的图池化方法。使用图卷积的self-attention使得化池方法同时考虑了节点特征和图的拓扑结构。实验中,为了确保公平的比较,将现有的池化方法和文中提出的SAGPool方法使用了相同的训练流程和模型架构。实验结果表明,该方法使用合理数量的参数,在基准数据集上获得了较好的图分类性能。

1 相关介绍

背景

目前,图池化的方法比图卷积的方法要少,而现存的基于池化的方法存在一些问题:

  • 以往的基于拓扑的池化的研究都只考虑了图的拓扑结构。
  • 而全局池化方法只考虑了图的特征。
  • 分层池化可以学到图的层次表示(DIFFPOOL),允许图神经网络(GNNs)以端到端方式汇聚后获得按比例缩小的图,但是具有二次方的空间复杂度,其参数数量取决于节点数量。分层池化gPool解决了复杂度问题,但没有考虑图的拓扑结构。

创新性

文中提出了SAGPool,这是一种基于层次图池化的Self-Attention Graph方法。

  • SAGPool方法可以使用相对较少的参数以端到端方式学习分层表示
  • 利用self-attention机制来区分应该删除的节点和应该保留的节点
  • 基于图卷积计算注意力分数的self-attention机制,考虑了节点特征和图的拓扑结构

简而言之,SAGPool具有前几种方法的优点:分层池化,同时考虑节点特征和图的拓扑结构,合理的复杂度,以及端到端表示学习。SAGPool是第一个使用self-attention进行图池化处理并实现高性能的方法。SAGPool参数量一致,不用考虑输入图的大小。

2 相关工作

GNNs因其在图领域中以最先进的表现而备受关注。CNN模型中的池化层通过缩小representations的大小来减少参数的数量,从而避免过拟合问题。为了将CNNs推广到图数据上,GNNs使用池化方法是必要的。图数据池化方法可以分为以下三类:基于拓扑的池化、全局池化和分层池化。

基于拓扑的池化

早期的工作使用的是图的粗化算法,而不是使用神经网络。谱聚类算法利用特征分解得到粗化图。然而,由于特征分解的时间复杂度,需要一些替代方法:
(1) Weighted graph cuts without eigenvectors a multilevel approach, 2007
Graclus计算了给定图无特征向量的的聚类,因为一般谱聚类目标和加权核k-means目标之间存在数学等价性。
(2)在最近的GNN模型中Graclus被用作池化模块:

  • Convolutional neural networks on graphs with fast localized spectral filtering,NIPS 2016
  • Hybrid approach of relation network and localized graph convolutional filtering for breast cancer subtype classification,IJCAI-18

全局池化

与前面的方法不同,全局池方法考虑了图的特征。全局池方法使用求和或神经网络pool每个层中所有节点的表示。全局池方法pool所有的表示,可以处理具有不同结构的图:
(1)Neural message passing for quantum chemistry,2017)
将GNNs视为消息传递方案,提出了一种图:分类的通用框架,利用Set2Set方法可以获得整个图的表示。
(2)An end-to-end deep learning architecture for graph classification,AAAI 2018
SortPool:根据图的结构对节点的embeddings进行排序,并将排序后的embeddings传递给下一层。

分层池化

全局池方法没有学习对捕获图结构信息至关重要的层次表示。分层池化方法的主要动机是建立一个能够学习每一层中基于特征或拓扑的节点assignment的模型:
(1)[DIFFPOOL] Hierarchical Graph Representation Learning with Differentiable Pooling,NeurIPS 2018
DIFFPOOL是一种可微的图池方法,能够以端到端的方式学习assignment矩阵:S(l)Rnl×nl+1S^{(l)} \in \mathbb{R}^{n_{l} \times n_{l+1}}nln_l表示在第ll层的节点数,nl+1n_{l+1}表示在第l+1l+1层的节点数,nl<nl+1n_l < n_{l+1},即行数等于第ll层的节点数(cluster数),列数代表第l+1l+1层的节点数(cluster数)。assignment matrix表示第ll层的每一个节点到第l+1l+1层的每一个节点(或cluster)的概率。

具体而言,节点根据下面的公式分到下一层的cluster:

S(l)=softmax(GNNl(A(l),X(l))) S^{(l)}=\operatorname{softmax}\left(\mathrm{GNN}_{l}\left(A^{(l)}, X^{(l)}\right)\right)

A(l+1)=S(l)A(l)S(l)(1) \tag{1} A^{(l+1)}=S^{(l) \top} A^{(l)} S^{(l)}
具体细节,可以参考另一篇博文:[DIFFPOOL 图分类] - Hierarchical Graph Representation Learning with Differentiable Pooling NeurIPS 2018

(2)Graph u-net,ICML 2019
gPool实现了与DiffPool相当的性能。gPool需要O(V+E)O(|V| + |E|)的空间复杂度,而DiffPool需要O(kV2)O(k|V|^2),其中V,E,kV,E,k分别表示顶点数、边数和池化比率。gPool使用一个可学习的向量pp来计算投影分数,然后使用这些分数来选择排名最高的节点。投影分数由pp与各节点特征的点积得到。分数表示可以保留的节点信息量。下式大致描述了gPool中的池化过程:

y=X(l)p(l)/p(l), id x=toprank(y,kN) y=X^{(l)} \mathbf{p}^{(l)} /\left\|\mathbf{p}^{(l)}\right\|, \quad \text { id } \mathbf{x}=\operatorname{top}-\operatorname{rank}(y,\lceil k N\rceil)

A(l+1)=Aidx,idx(l)(2) \tag{2} A^{(l+1)}=A_{\mathrm{idx}, \mathrm{idx}}^{(l)}
在公式(2)中,投影分数的计算没有考虑图的拓扑结构。

为了进一步改进图池化方法,文中提出了SAGPool,它可以使用图的特征和拓扑结构信息来产生具有合理的时间和空间复杂度的层次表示。

3 方法

SAGPool的关键在于它使用GNN来提供self-attention分数。

3.1 基于self-attention的图池化方法:SAGPool

SAGPool - Self-Attention Graph Pooling 图分类 图池化方法 ICML 2019
  • 图1是SAGPool层的结构图
Self-attention mask

注意机制在最近的深度学习研究中被广泛应用。这种机制可以使我们能够关注更重要的特征。特别是,self-attention,通常被称为intra-attention,允许关注的特征是注意力本身。SAGPool利用图卷积的方法得到self-attention分数。例如,如果使用Kipf的图卷积公式,则self-attention分数ZRN×1Z \in \mathbb{R}^{N \times 1}根据如下计算:

Z=σ(D~12A~D~12XΘatt)(3) \tag{3} Z=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} X \Theta_{a t t}\right)

  • ΘattRF×1\Theta_{a t t} \in \mathbb{R}^{F \times 1}是SAGPool层中的唯一参数
  • 其他符号表示和GCN中相同

通过利用图卷积得到self-attention分数,池的结果是基于图的特征和拓扑的。SAGPool采用了gPool中的节点选择方法,这种方法保留了输入图的一部分节点,即使在输入不同大小和结构的图时也是如此。池化比率k(0,1]k \in(0,1]是一个超参数,它决定要保留的节点数。topkN\lceil k N\rceil的节点是根据ZZ的值来选择的。

idx=toprank(Z,kN),Zmask=Zidx(4) \tag{4} \mathrm{idx}=\operatorname{top}-\operatorname{rank}(Z,\lceil k N\rceil), \quad Z_{m a s k}=Z_{\mathrm{idx}}

  • 其中toprank\operatorname{top}-\operatorname{rank}是返回topkN\lceil k N\rceil的节点的索引的函数
  • .idx._{idx}是一个索引操作
  • ZmaskZ_{m a s k}是一个特征attention mask
图池化

输入图由图1中标记为masking的操作处理。

X=Xidx,i,Xout=XZmask,Aout=Aidx,idx(5) \tag{5} X^{\prime}=X_{\mathrm{idx}, \mathrm{i}}, \quad X_{o u t}=X^{\prime} \odot Z_{m a s k}, \quad A_{o u t}=A_{\mathrm{idx}, \mathrm{idx}}

  • Xidx,:X_{\mathrm{idx}, \mathrm{:}}是索引按行(按节点)排列的特征矩阵
  • XoutX_{o u t}是新的特征矩阵
  • AoutA_{o u t}是新的邻接矩阵
  • Aidx,idxA_{\mathrm{idx}, \mathrm{idx}}是按行和按列索引的邻接矩阵
SAGPool的变种

SAGPool中使用图卷积的主要原因是为了反映拓扑结构和节点特征。如果GNNs以节点特征和邻接矩阵为输入,则可以用式(3)代替GNNs的各种公式。计算注意力分数ZRN×1Z \in \mathbb{R}^{N \times 1}的广义方程如下:

Z=σ(GNN(X,A))(6) \tag{6} Z=\sigma(\operatorname{GNN}(X, A))

计算注意力分数,不仅可以使用相邻节点,也可以使用多跳连接的节点。可以使用添加改变邻接矩阵形式扩展边,堆叠多层GNN层,使用多个注意力分数的平均值等方法来达到这个目的。
以一个连接两跳的节点为例。
(1)添加邻接矩阵的平方: SAGPoolaugmentation\text { SAGPool}_{\text {augmentation}}
式(7)使用了两跳连接,该连接涉及边的扩展,允许两跳节点的间接聚合。添加邻接矩阵的平方相当于在两跳邻居之间创建了边:

Z=σ(GNN(X,A+A2))(7) \tag{7} Z=\sigma\left(\operatorname{GNN}\left(X, A+A^{2}\right)\right)

(2)叠加两层GNN层: SAGPoolserial\text { SAGPool}_{\text {serial}}
式(8)使用了两跳连接,该连接涉及GNN层的堆叠,允许两跳节点的间接聚合。在这种情况下,SAGPool层的非线性和参数数量将增加:

Z=σ(GNN2(σ(GNN1(X,A)),A))(8) \tag{8} Z=\sigma\left(\operatorname{GNN}_{2}\left(\sigma\left(\operatorname{GNN}_{1}(X, A)\right), A\right)\right)
公式(7)和公式(8)可以应用到更多跳的连接上。

(3)多个注意力分数的平均值: SAGPoolparallel\text { SAGPool}_{\text {parallel}}
另一种方法是多个注意力得分的平均值。MM个GNNs平均注意力分值如下:

Z=1Mmσ(GNNm(X,A))(9) \tag{9} Z=\frac{1}{M} \sum_{m} \sigma\left(\mathrm{GNN}_{m}(X, A)\right)

文中分别将公式(7),(8),(9)中的模型称为 SAGPoolaugmentation\text { SAGPool}_{\text {augmentation}} SAGPoolserial\text { SAGPool}_{\text {serial}} SAGPoolparallel\text { SAGPool}_{\text {parallel}}

3.2 模型架构

SAGPool - Self-Attention Graph Pooling 图分类 图池化方法 ICML 2019
  • 图2

根据(Troubling trends in machine learning scholarship,2018)的观点,如果对一个模型进行了多次修改,那么就很难确定哪些修改有助于提高性能。为了公平的比较,文中采用了SortPool(AAAI 2018)和(Towards sparse hierarchical graph classifiers,2018)中的模型架构,并使用相同的架构来比较baseline和文中的方法。

卷积层

使用Kipf的GCN:
h(l+1)=σ(D~12A~D~12h(l)Θ)(10) \tag{10} h^{(l+1)}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} h^{(l)} \Theta\right)

  • ΘRF×F\Theta \in \mathbb{R}^{F \times F^{\prime}}
  • FFFF^{\prime}分别表示第l+1l+1层的输入特征维度和输出特征维度
  • **函数使用ReLU
readout层

受JK-net架构(Representation learning on graphs with jumping knowledge networks,2018;Towards sparse hierarchical graph classifiers,2018)的启发,提出了一种readout层,该层聚合节点特征以形成固定大小的表示。readout层的输出特征如下:

s=1Ni=1Nximaxi=1Nxi(11) \tag{11} s=\frac{1}{N} \sum_{i=1}^{N} x_{i} \| \max _{i=1}^{N} x_{i}

  • NN表示节点数量
  • xix_i表示第ii个节点的特征
  • ||,表示concatenation,即特征串联操作
全局池化架构

实现了(An end-to-end deep learning architecture for graph classification,AAAI 2018)中提出的全局池化架构。
如图2所示,全局池化结构由三个图卷积层组成,每层的输出被连接起来。节点特征在池化层之后的readout层中聚合。然后将图的特征表示传递到线性层进行分类。

分层池化架构

实现了(Towards sparse hierarchical graph classifiers,2018)最近的分层池化研究中的架构。如图2所示,架构由三个block组成,每个block由一个图卷积层和一个图池化层组成。每个block的输出汇总在readout层中。将每个readout层输出的总和输入到线性层进行分类。

4 实验

在图分类任务中,评估了全局池化和分层池化方法。

数据集

选取了5个图的数量较大的数据集(>1k> 1k):

  • D&D
    包含蛋白质结构的图。一个节点表示一个氨基酸,如果两个节点之间的距离小于6A^6 \hat{\mathrm{A}},则构造一挑边。标签表示蛋白质是酶还是非酶。
  • PROTEINS
  • 也是一组蛋白质,其中节点是二级结构元素。如果节点具有边,则节点处于氨基酸序列中或在封闭的三维空间中。
  • NCI1
    是一个用于抗癌活性分类的生物数据集。在数据集中,每个图表示一个化合物,节点和边分别表示原子和化学键。
  • NCI109
  • FRANKENSTEIN
    是一组具有包含连续值的节点特征的分子图。标签表示一个分子是诱变剂还是非诱变剂。
SAGPool - Self-Attention Graph Pooling 图分类 图池化方法 ICML 2019

GNNs的评估

  • 所有模型均采用相同的early stopping准则和超参数选择策略,以保证比较的公平性。
  • 使用NVIDIA TitanXp GPU
  • 使用几何深度学习扩展库PyG实现所有的baselines和SAGPool

训练过程

  • (Pitfalls of graph neural network evaluation,2018)证明了不同的数据分割会影响GNN模型的性能。
  • 在实验中,使用10-fold交叉验证评估了超过20个随机种子的池化方法。
  • 总共使用了200个测试结果来获得每个数据集上每个方法的最终精度。
  • 10%的训练数据在训练过程中用于验证。使用了Adam优化器、early stopping准则、patience以及全局池化结构和分层池化结构的超参数选择策略。如果在epoch终止条件下最多100k个epoch验证损失没有改善,将停止训练。
  • 通过网格搜索得到最优超参数。网格搜索的范围如表2所示。
SAGPool - Self-Attention Graph Pooling 图分类 图池化方法 ICML 2019

Baselines

  • Set2Set
    Set2Set需要额外的超参数,即LSTM模块的处理step数。假设readout层是不必要的,因为LSTM模块为节点顺序不变的图生成embedding。
  • SortPool
    SortPool是一种全局池化方法,它使用排序来池化。设置了节点数量为KK,使得60%的图有多于KK个的节点。在全局池化设置中, SAGPool g\text { SAGPool }_{g}具有与SortPool相同的KK个输出节点。
  • DiffPool
    DiffPool是第一种端到端可训练的图池化方法,它可以生成图的层次表示。使用中没有对DiffPool使用batch normalization,因为这与池化方法无关。对于超参数搜索,池化比率从0.25到0.5不等。在引用实现中,将cluster大小设置为节点的最大数目的25%。当池化比率大于0.5时,DiffPoolh会导致内存不足。
  • gPool
    gPool为池化选择排名靠前的节点,与SAGPool方法类似。通过与gPool的比较表明,考虑拓扑结构有助于提高图形分类任务的性能。

DiffPool,gPool和 SAGPool h\text { SAGPool }_{h}使用分层池化架构;
Set2Set, SortPool和 SAGPool g\text { SAGPool }_{g}使用全局池化架构;
对所有的baselines和SAGPool使用相同的超参数搜索策略。超参数如表2所示。

SAGPool的变种

  • SAGPool的三个变体用于获得注意力得分ZZ
  •  SAGPool augmentation\text { SAGPool }_{augmentation}使用公式(7)
  •  SAGPool serial\text { SAGPool }_{serial}使用公式(8)
  •  SAGPool paralle\text { SAGPool }_{paralle}使用多个GNN计算注意得分,对得分求平均以获得最终注意得分
  • 使用方程(9)评估M=2和M=4的性能

Z=1Mmσ(GNNm(X,A))(9) \tag{9} Z=\frac{1}{M} \sum_{m} \sigma\left(\mathrm{GNN}_{m}(X, A)\right)

结果总结

SAGPool - Self-Attention Graph Pooling 图分类 图池化方法 ICML 2019
  • 结果表明,SAGPool在总体上表现良好,在D&D和PROTEINS方面表现尤为突出。
  • 在实验中,SAGPool在所有的数据集上都优于分层池化的方法。
  • SAGPool变体的实验结果表明,SAGPool具有提高性能的潜力。

5 分析

全局池化和分层池化

  • 很难确定全局池化结构或层次池化结构是否完全有利于图形分类。
  • 因为全局池化结构POOLgP O O L_{g} SAGPool g\text { SAGPool }_{g} SortPool g\text { SortPool }_{g} Set2Set g\text { Set2Set }_{g})使信息丢失最小化,因此它在节点较少的数据集(NCI1、NCI109、FRANKENSTEIN)上的性能优于分层池化结构POOLhP O O L_{h} SAGPool h\text { SAGPool }_{h} gPool h\text { gPool }_{h} DiffPool h\text { DiffPool }_{h})。
  • 然而,POOLhP O O L_{h}对节点数较多的数据集(D&D和PROTEINS)更有效,因为它能有效地从大规模图中提取有用的信息。
  • 因此,使用最适合给定数据的池化结构非常重要。
  • 尽管如此,SAGPool在每种架构中通常都表现良好。

考虑图拓扑结构的影响

  • 为了计算节点的注意力分数, SAGPool h\text { SAGPool }_{h}利用公式(3)所示的图卷积。
  • 和gPool不一样, SAGPool使用一阶近似图的拉普拉斯算子D~12A~D~12\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}},这使得SAGPool考虑了图的拓扑结构。如表3所示,考虑图的拓扑结构可以提高性能。此外,图的Laplacian算子不需要重新计算,因为它在前一个图卷积层中也使用了,可以预先计算。
  • 虽然SAGPool具有与gPool相同的参数,但它在图分类任务中表现出了更优异的性能。
SAGPool - Self-Attention Graph Pooling 图分类 图池化方法 ICML 2019
  • 图3:图中参数的数量随着节点的增多而增大。
  • x轴标号为输入图节点数
  • y轴为分层池化模型参数:输入节点特征数为128,隐含层特征大小为128,class数为2。
  • SAGPool使用公式(3)的图卷积。
  • kk表示池化比率,k=1.0k = 1.0表示池化后保留整个节点。
  • 无论输入图的大小和池化比率如何,gPool和SAGPool参数数量都一致。

稀疏实现

  • 使用稀疏矩阵操作图数据对于GNNs来说非常重要,因为邻接矩阵通常是稀疏的。
  • 用稠密矩阵计算图卷积时,乘法AXAX的计算复杂度为O(V2)O(|V|^2),其中AA为邻接矩阵,XX为节点特征矩阵,VV为顶点。
  • 如(Towards sparse hierarchical graph classifiers,2018)所述,密集矩阵池化会导致内存效率问题。
  • 如果在同一操作中使用稀疏矩阵,则计算复杂度降低到O(E)O(|E|),其中EE表示边。由于SAGPool是一种稀疏池化方法,使用稀疏实现可以降低计算复杂度,而DiffPool是一种密集池化方法,计算复杂度较高。
  • 稀疏性也影响空间复杂性。因为SAGPool使用GNN来获取注意力分数,所以SAGPool需要O(V+E)O(|V |+|E|)的稀疏池化存储空间,而稠密池化方法需要O(V2)O(|V|^2)

节点数量的关系

在DiffPool中,由于GNN产生了如式(1)所示的assignment矩阵S,因此在构建模型时必须定义cluster的大小。根据参考方法的实现,cluster的大小必须与最大节点数成比例。DiffPool的这些要求会导致两个问题。

  • 参数的数量取决于最大节点数,如图3所示。
  • 当节点数量变化很大时,很难确定正确的cluster大小。例如,在1178个图中只有10个图具有超过1000个节点,其中最大节点数为5748,最小节点数为30。如果池化比率为10%,则cluster大小为574,这将扩展大多数数据池化后的图大小。

在SAGPool中,参数的数量与cluster的大小无关。此外,可以根据输入节点的数量更改cluster大小。

SAGPool变种比较

为了研究SAGPool方法的潜力,在两个数据集上评估了SAGPool的变种。可以用以下操作修改SAGPool:

  • 更改GNN的类型
  • 考虑两跳连接。实验中使用两个连续GNN层( SAGPool serial\text { SAGPool }_{serial})和添加邻接矩阵的方式来实现2跳连接( SAGPool augmentation\text { SAGPool }_{augmentation})。
  • 对多个GNN的注意力得分求平均值
SAGPool - Self-Attention Graph Pooling 图分类 图池化方法 ICML 2019
  • 使用的数据集和SAGPool中的GNN类型不同,图分类的性能有所不同
  • 两跳邻居的信息有助于提高性能
  • 发现为数据集选择合适的MM值有助于实现稳定的性能

当前方法的局限性

  • 保留一定比例(池化比率kk)的节点来处理不同大小的不同输入图,这在之前的研究中也做过
  • 在SAGPool中,无法为每个图将池化比率参数化来找到最优值。为了解决这个问题,文中使用二分类来决定保留哪些节点,但是这并没有完全解决问题。

可能的扩展工作

  • 为每个图使用可需要的池化比率来获得最优的cluster大小
  • 研究每个池化层中多attention mask的影响