STG2Seq:多步乘车需求预测的时空图序列模型
新南威尔士大学发表在IJCAI 2019的一篇论文,题目标题为STG2Seq: Spatial-temporal Graph to Sequence Model for Multi-step Passenger Demand Forecasting,谷歌学术目前引用量为10。
Abstract
现存的问题:
多步乘车需求预测是车辆共享服务中的一个重要问题,其非线性和动态的时空依赖性具有挑战性。
本文的解决方案:
本文提出一种基于图的模型来建模城市的多步乘车需求预测,并使用分层的卷积结构来同时捕获空间和时间关联。本文提出的模型包括三个部分:(1)一个历史的长期编码器来对历史乘客需求进行建模;(2)产生多步预测的下一步预测的短期编码器;(3)一个基于注意力机制的输出模块来建模动态时间和通道间的信息。
Introduction
预测乘客需求预测存在两个挑战,非线性和动态时空关联。因为未来某一时刻的乘车需求不仅受该区域历史需求的影响,也受城市其他区域的影响。
传统的方法使用RNN及其变体LSTM或GRU来捕获时间关联,基于CNN来捕获空间关联,这些方法存在一些局限:
- 基于CNN的方法,包括ConvLSTM,通常把城市划分成规则的网格区域,但这并不总是成立的。这些方法只能捕捉到欧式空间中相邻区域的影响,而非欧式空间中距离远的区域不能很好捕获其关联性(CNN卷积核每次只能对局部进行特征提取)。
- 现有的方法严重依赖RNN类似的迭代结构来捕获时间关联,这在长距离的时间序列中会造成信息的损失,和错误的累计(RNN固有的缺陷,梯度消失和多步预测错误累计)
- 现有的方法没有捕获时间关联的动态依赖性,只能反应历史数据的集体影响。
上图显示了时间依赖的动态性,t1,t2,t2,t3对t4的影响各不相同,与t5,t6,t7对t8的影响模式也不同。
本文的贡献:
- 提出在一张图表上列出了全市的旅客需求,使用基于GCN的sequence to
sequence模型来建模城市范围的多步乘车需求预测。(首次应用图卷积用于多步预测) - 提出了一个基于注意力的输出模块,以捕捉最具影响力的历史时间步长对预测需求的影响以及这些关系所固有的动态性。
Methodology
问题定义:
将城市划分成N个区域,为t时刻的N个区域,2D矩阵∈代表t时刻的城市乘客需求;
向量∈代表时间步t的特征,包括一天中的时间,一周中的第几天,和节假日信息。
文章的问题就是给定历史乘客需求序列和时间特征来学习一个预测函数,预测未来τ步时间的乘车需求量:
定义图,v是节点集合,代表N个区域,ξ是边集合,A是领接矩阵,邻接矩阵根据不同区域之间的乘客需求相似性定义,ε为控制矩阵稀疏性的阈值:
使用皮尔逊相似性计算区域间的乘客需求相似度,计算两个区域i,j历史时间0~t的需求相似度公式如下:
长期编码器和短期编码器:
多步预测中,通常把上一时刻的输出当作下一时刻的输入,这会导致错误的累积,加速模型的崩溃。本文提出同时利用依赖与长期编码和短期编码来达到多步预测,并未使用RNN结构。
多步预测类似于NLP中的文本生成任务,如机器翻译或文本摘要,大多基于Sequence-to-Sequence框架,使用前一步输出当作后一步输入,会导致文本生成不准确性,因此利用注意力机制的上下文向量对输入输出进行相似度计算从而进行对齐,每步输出的上下文向量都不相同。
整个模型由长期编码器、短期编码器、基于注意力机制的输出模块组成:
长期编码器的输入为h步历史数据:的3D的立方体,h为时间步,N为节点个数,为节点特征维度。其中,长期编码器由多个GCCM模块组成,其中每个GGCM捕获所有N个区域之间的空间相关性和k(斑块大小,超参数)时间步长之间的时间相关性。长期编码器总共需要迭代步来捕获历史h步数的时间关联,
长期编码器的输出为一个的矩阵
短期编码器用于为多步预测集成已经预测的需求,使用一个大小为q的滑动窗口来捕获时空相关性,q与h类似,其输入为 3D tensor,输出为的矩阵长短期编码器的区别仅在与编码的历史数据的长度。
模型的重点就是GGCM模块了,长短期编码器都使用了这一模块来捕获时间关联,类似RNN中的Cell。
GGCM使用图卷积提取时空特征,在输入的 3D tensor先paading一层 ,(padding是为了做图卷积),输入变成的3D tensor,
GCN的堆叠,类似于CNN的堆叠,越高层的GCN感受域越大,输入的最上两层为padding
图卷积如下,,表示对邻接矩阵按行求和对角化的度矩阵,∈为reshape后的需求矩阵
使用了重新设计的门控来建模非线性,下面公式中左边为线性变换,右边为门,控制哪些信息流入下一层,中间为点积:
GGCM输出的的维度为 。
基于注意力的输出模块:
将长期和短期编码器的输出进行拼接,得到一个维度为的3D tensor,然后用channel wise(通道注意力机制)进行重要时间戳的提取,提高预测的准确度。
实验
在三个数据集上进行实验,结果优于基线
去除各个组件的消融实验,证明了各个组件的有效性:
总结
基于门控的长短期编码器的GGCM模块,堆叠GGCM可以减少时间的迭代,捕获时空关联;输出模块将长短期编码器的输出进行contanacate,其chanel wise的attention设计可以捕获对预测结果更相关的时间戳输入。总体感觉用堆叠的门控机制可以代替RNN,IJCAI 2018的STGCN中GLU也是类似的设计,IJCAI 2019 Graph Wave net那篇的GTCN也是通过堆叠门控CNN使用扩散图卷积捕获了时间上的关联。这种基于门控的堆叠GCN或CNN的设计比RNN更有效,一定程度防止了梯度消失和预测过程中错误的累计,其训练速度也更快。