A Small Study of Neural ODEs
The code above solves the ODE with Euler's method.
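Below is a minimal, self-contained sketch of the Euler update h(t + dt) ≈ h(t) + dt·f(t, h(t)). It uses a scalar state and the test problem dh/dt = -h with h(0) = 1, whose exact solution is h(t) = exp(-t). Both the problem and the function name `euler_solve` are assumptions chosen for illustration.

```python
import math

def euler_solve(f, h0, t0, t1, n_steps):
    """Integrate dh/dt = f(t, h) from t0 to t1 using n_steps explicit Euler steps."""
    dt = (t1 - t0) / n_steps
    t, h = t0, h0
    for _ in range(n_steps):
        h = h + dt * f(t, h)  # h(t + dt) ≈ h(t) + dt * f(t, h(t))
        t = t + dt
    return h

# Assumed test problem: dh/dt = -h, h(0) = 1, exact solution h(1) = exp(-1)
f = lambda t, h: -h
approx = euler_solve(f, 1.0, 0.0, 1.0, 100)
print(approx, math.exp(-1.0))
```

Euler's method is first order, so a tenfold increase in the number of steps should shrink the error roughly tenfold.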
Using the midpoint method (or RK2), a 2nd-order method, only requires:
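Compared with Euler's method, only the per-step update changes. The midpoint method first evaluates the slope at the half step and then takes a full step using that slope. The sketch below uses the same assumed test problem, dh/dt = -h; the names `midpoint_step` and `midpoint_solve` are illustrative.

```python
import math

def midpoint_step(f, t, h, dt):
    """One explicit midpoint (RK2) step: slope at the half step, then a full step with it."""
    k1 = f(t, h)                          # slope at the start of the step
    k2 = f(t + dt / 2, h + dt / 2 * k1)   # slope at the estimated midpoint
    return h + dt * k2

def midpoint_solve(f, h0, t0, t1, n_steps):
    dt = (t1 - t0) / n_steps
    t, h = t0, h0
    for _ in range(n_steps):
        h = midpoint_step(f, t, h, dt)
        t += dt
    return h

f = lambda t, h: -h  # assumed test dynamics, exact h(1) = exp(-1)
err100 = abs(midpoint_solve(f, 1.0, 0.0, 1.0, 100) - math.exp(-1.0))
print(err100)
```

Because the method is second order, doubling the number of steps should cut the error by about a factor of four.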
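Here, odeint is a generic ODE solver: it must be given the dynamics function fun(t, ht), the initial condition, the time points at which to evaluate the solution, and the solver (method) to use.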
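The interface just described can be sketched as a small hypothetical `odeint`. It takes the four required inputs: the dynamics, the initial state, the evaluation times, and a method name. This is not the signature of any particular library; the `substeps` parameter and the `METHODS` table are assumptions made for illustration.

```python
import math
from typing import Callable, List, Sequence

Dynamics = Callable[[float, float], float]

def euler_step(f: Dynamics, t: float, h: float, dt: float) -> float:
    return h + dt * f(t, h)

def midpoint_step(f: Dynamics, t: float, h: float, dt: float) -> float:
    return h + dt * f(t + dt / 2, h + dt / 2 * f(t, h))

# The "solver" argument: any one-step method with this signature can be plugged in.
METHODS = {"euler": euler_step, "midpoint": midpoint_step}

def odeint(fun: Dynamics, h0: float, t: Sequence[float],
           method: str = "euler", substeps: int = 10) -> List[float]:
    """Hypothetical generic solver: return the state at every time point in t.

    fun(t, h) is the dynamics, h0 the initial condition at t[0], t the evaluation
    times, and method selects the one-step solver. Each interval between consecutive
    evaluation times is split into `substeps` fixed steps.
    """
    step = METHODS[method]
    h = h0
    out = [h0]
    for t_start, t_end in zip(t[:-1], t[1:]):
        dt = (t_end - t_start) / substeps
        for i in range(substeps):
            h = step(fun, t_start + i * dt, h, dt)
        out.append(h)
    return out

ts = [0.0, 0.5, 1.0]
hs = odeint(lambda t, h: -h, 1.0, ts, method="midpoint")
print(hs, [math.exp(-x) for x in ts])
```

Keeping the method behind a name lets callers trade accuracy for cost without touching the dynamics function.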
Higher-order methods such as Runge–Kutta (RK4) or Adams–Bashforth generally give better numerical accuracy for the same step size.
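The accuracy gain can be checked numerically by comparing the classic RK4 step with Euler on the same grid. The sketch below uses the assumed test problem dh/dt = -h. Adams–Bashforth, a multistep method, is not shown here.

```python
import math

def euler_step(f, t, h, dt):
    return h + dt * f(t, h)

def rk4_step(f, t, h, dt):
    """Classic 4th-order Runge-Kutta step: weighted average of four slope evaluations."""
    k1 = f(t, h)
    k2 = f(t + dt / 2, h + dt / 2 * k1)
    k3 = f(t + dt / 2, h + dt / 2 * k2)
    k4 = f(t + dt, h + dt * k3)
    return h + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def solve(step, f, h0, t0, t1, n_steps):
    dt = (t1 - t0) / n_steps
    h = h0
    for i in range(n_steps):
        h = step(f, t0 + i * dt, h, dt)
    return h

f = lambda t, h: -h      # assumed test dynamics
exact = math.exp(-1.0)   # exact h(1)
for n in (10, 20):
    e_euler = abs(solve(euler_step, f, 1.0, 0.0, 1.0, n) - exact)
    e_rk4 = abs(solve(rk4_step, f, 1.0, 0.0, 1.0, n) - exact)
    print(f"n={n:3d}  euler error={e_euler:.2e}  rk4 error={e_rk4:.2e}")
```

When the step is halved, the Euler error should roughly halve, while the RK4 error should drop by about a factor of 16. RK4 pays for this with four function evaluations per step instead of one.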
All of these can be implemented behind the same generic interface (e.g. scipy).
Integrating neural networks with ODE solvers
The results are shown in the figure.
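Here is a minimal sketch of the integration idea in pure Python. The neural network is simply the right-hand side f(t, h) that the solver calls, and the "ODE block" output is the state at t = 1. The one-layer network tanh(W h + b), its random weights, and the helper names (`make_mlp_dynamics`, `odeblock`, `residual_block`) are hypothetical. The last helper shows that a ResNet-style update h + f(h) is one Euler step with dt = 1.

```python
import math
import random

def make_mlp_dynamics(dim, seed=0):
    """Hypothetical one-layer network f(t, h) = tanh(W h + b) used as the ODE right-hand side."""
    rng = random.Random(seed)
    W = [[rng.gauss(0.0, 0.5) for _ in range(dim)] for _ in range(dim)]
    b = [rng.gauss(0.0, 0.1) for _ in range(dim)]

    def f(t, h):
        return [math.tanh(sum(W[i][j] * h[j] for j in range(dim)) + b[i]) for i in range(dim)]
    return f

def rk4_step(f, t, h, dt):
    """Classic RK4 step for a list-valued state h."""
    def shift(k, c):  # elementwise h + c * k
        return [hi + c * ki for hi, ki in zip(h, k)]
    k1 = f(t, h)
    k2 = f(t + dt / 2, shift(k1, dt / 2))
    k3 = f(t + dt / 2, shift(k2, dt / 2))
    k4 = f(t + dt, shift(k3, dt))
    return [hi + dt / 6 * (a1 + 2 * a2 + 2 * a3 + a4)
            for hi, a1, a2, a3, a4 in zip(h, k1, k2, k3, k4)]

def odeblock(f, h0, t1=1.0, n_steps=20):
    """Forward pass of an 'ODE block': integrate the network dynamics from t = 0 to t1."""
    dt = t1 / n_steps
    h = list(h0)
    for i in range(n_steps):
        h = rk4_step(f, i * dt, h, dt)
    return h

def residual_block(f, h):
    """ResNet-style update h + f(h): identical to one Euler step with dt = 1."""
    return [hi + fi for hi, fi in zip(h, f(0.0, h))]

f = make_mlp_dynamics(dim=3)
x = [1.0, -0.5, 0.25]
print(odeblock(f, x), residual_block(f, x))
```

In a real model, W and b would be trained by backpropagating through, or adjoint-integrating, the solver. That training step is not part of this sketch.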
● We can use existing (and efficient) solver implementations to integrate NN dynamics.
● The memory cost is O(1), i.e. constant in the number of solver steps, thanks to reversibility: we don't need to store all activations in the graph, because we can recover them by backward integration (i.e. integration in reversed time).
● Complex dynamics can be modeled with fewer parameters.
● We can control the accuracy/speed trade-off with adaptive solvers by setting lower/higher error tolerances.
● Hidden states can be accessed at any value of t: there are no discrete time steps as in ResNet skip connections.
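The reversibility, tolerance, and any-t bullets can be illustrated together with one small adaptive solver. The sketch uses assumptions throughout:

- The solver is a simple Heun–Euler embedded pair, swapped in for illustration. Production solvers typically use higher-order pairs such as Dormand–Prince.
- The test problem is dh/dt = -h.

The sketch shows three things:

- A tighter tolerance costs more steps but gives a smaller error.
- Integrating backward from h(1) alone recovers h(0). The recovery is only up to solver error, which is why the tolerance matters.
- The state can be queried at an arbitrary t.

```python
import math

def adaptive_heun_euler(f, h0, t0, t1, tol, dt0=0.1):
    """Integrate dh/dt = f(t, h) from t0 to t1 (t1 may be < t0) with an embedded
    Heun-Euler pair. The step size adapts so the local error estimate stays below tol.
    Returns (h(t1), number of accepted steps, number of f evaluations)."""
    direction = 1.0 if t1 >= t0 else -1.0
    t, h = t0, h0
    dt = direction * abs(dt0)
    accepted = evals = 0
    while direction * (t1 - t) > 1e-12:
        if direction * (t + dt - t1) > 0:  # do not step past t1
            dt = t1 - t
        k1 = f(t, h)
        k2 = f(t + dt, h + dt * k1)
        evals += 2
        h_euler = h + dt * k1              # 1st-order solution
        h_heun = h + dt / 2 * (k1 + k2)    # 2nd-order solution
        err = abs(h_heun - h_euler)        # local error estimate
        if err <= tol:                     # accept and keep the higher-order result
            t, h = t + dt, h_heun
            accepted += 1
        # grow or shrink the step based on how the estimate compares with tol
        factor = 2.0 if err == 0 else min(2.0, max(0.2, 0.9 * math.sqrt(tol / err)))
        dt *= factor
    return h, accepted, evals

f = lambda t, h: -h  # assumed test dynamics; exact solution h(t) = exp(-t)

for tol in (1e-3, 1e-6):
    h1, steps, evals = adaptive_heun_euler(f, 1.0, 0.0, 1.0, tol)
    # Reversibility: integrate backward from t=1 to t=0 starting from h(1) only.
    h0_rec, _, _ = adaptive_heun_euler(f, h1, 1.0, 0.0, tol)
    print(f"tol={tol:.0e}: {steps} steps, {evals} f-evals, "
          f"|h(1) error|={abs(h1 - math.exp(-1.0)):.1e}, recovered h(0)={h0_rec:.6f}")

# The state can be queried at any t, not only on a fixed grid of layers.
h_mid, _, _ = adaptive_heun_euler(f, 1.0, 0.0, 0.37, 1e-6)
print(h_mid, math.exp(-0.37))
```

For dh/dt = -h, backward integration is well behaved. For dynamics that are strongly contracting forward in time, the reversed problem can be ill-conditioned, so exact recovery is not guaranteed in general.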
NeuralODE - adjoint method
● The adjoint method can be understood as a continuous version of the chain rule.
● Chain rule: consider the following sequence of operations (L is a scalar loss):
● We can compute the gradient of L w.r.t. the input state using the chain rule.
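As an illustration, assume a sequence of N operations h_{k+1} = f_k(h_k) ending in a scalar loss L = ℓ(h_N); this notation is assumed and not taken from the original slides. The chain rule then reads as follows, with the adjoint a_k defined as a row vector. The last two lines show the continuous limit for Euler steps, which is the sense in which the adjoint method is a continuous chain rule.

```latex
h_{k+1} = f_k(h_k), \quad k = 0, \dots, N-1, \qquad L = \ell(h_N)

\frac{\partial L}{\partial h_0}
  = \frac{\partial L}{\partial h_N}\,
    \frac{\partial h_N}{\partial h_{N-1}} \cdots
    \frac{\partial h_1}{\partial h_0}

a_k := \frac{\partial L}{\partial h_k}, \qquad
a_N = \frac{\partial \ell}{\partial h_N}, \qquad
a_k = a_{k+1}\,\frac{\partial f_k}{\partial h_k}

% Euler steps h_{k+1} = h_k + \Delta t\, f(h_k, t_k) give
a_k = a_{k+1}\Bigl(I + \Delta t\,\frac{\partial f}{\partial h}(h_k, t_k)\Bigr)
\;\xrightarrow{\;\Delta t \to 0\;}\;
\frac{\mathrm{d}a(t)}{\mathrm{d}t} = -\,a(t)\,\frac{\partial f}{\partial h}\bigl(h(t), t\bigr),
\qquad a(T) = \frac{\partial L}{\partial h(T)}
```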
This formula is at the core of every deep learning autograd system.
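A small numerical check of the chain rule applied to a discretized ODE is sketched below. It assumes the following setup:

- Scalar dynamics dh/dt = sin(h).
- Euler steps.
- Loss L = h_N².

The gradient dL/dh_0 is accumulated backward with a_k = a_{k+1}·(1 + dt·f'(h_k)), which is what reverse-mode autograd does through the solver. The result is compared with two references:

- A finite difference through the same discrete forward pass.
- The exact continuous-time gradient. For this autonomous scalar ODE, tan(h/2) = tan(h_0/2)·e^t and dh(T)/dh_0 = f(h(T))/f(h_0).

```python
import math

def f(h):
    """Assumed scalar dynamics dh/dt = sin(h)."""
    return math.sin(h)

def df_dh(h):
    """Derivative of f, i.e. the local Jacobian needed by the chain rule."""
    return math.cos(h)

def forward(h0, dt, n_steps):
    """Euler steps h_{k+1} = h_k + dt * f(h_k); every state is kept for the backward pass."""
    hs = [h0]
    for _ in range(n_steps):
        hs.append(hs[-1] + dt * f(hs[-1]))
    return hs

def loss(h_final):
    return h_final ** 2  # assumed scalar loss L = h_N^2

def grad_chain_rule(h0, dt, n_steps):
    """Reverse-mode chain rule: a_k = dL/dh_k = a_{k+1} * dh_{k+1}/dh_k."""
    hs = forward(h0, dt, n_steps)
    a = 2 * hs[-1]                   # a_N = dL/dh_N
    for h_k in reversed(hs[:-1]):
        a *= 1 + dt * df_dh(h_k)     # dh_{k+1}/dh_k for one Euler step
    return a

h0, dt, n = 0.5, 0.01, 100            # integrate from t = 0 to T = n * dt = 1
g = grad_chain_rule(h0, dt, n)

# Check 1: central finite difference through the same discrete forward pass.
eps = 1e-6
g_fd = (loss(forward(h0 + eps, dt, n)[-1]) - loss(forward(h0 - eps, dt, n)[-1])) / (2 * eps)

# Check 2: exact continuous-time gradient, dL/dh0 = 2 h(T) * f(h(T)) / f(h0).
T = n * dt
h_T = 2 * math.atan(math.tan(h0 / 2) * math.exp(T))
g_exact = 2 * h_T * f(h_T) / f(h0)
print(g, g_fd, g_exact)
```

The discrete gradient matches the finite difference to rounding error. It differs from the continuous gradient only by the O(dt) discretization error, which is the gap the continuous adjoint closes as dt → 0.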