TensorFlow Series (15): Revisiting Momentum and Nesterov's Accelerated Gradient Descent: Using the PID Concept from Automatic Control to Introduce an Error-Derivative Control Hyperparameter That Improves NAGD, with Faster Convergence and Less Oscillation
The BP gradient-descent algorithm in neural networks and the PID algorithm in automatic control are similar in spirit: both use error feedback to solve a problem. The difference is that automatic control adjusts the system's input, whereas a neural network adjusts the system itself. This post introduces an error-derivative control hyperparameter, kd_damp, to improve the NAGD algorithm: convergence becomes faster and oscillation smaller!
Comparisons of the automatic-control PID algorithm with the neural-network Momentum algorithm:
http://www.sohu.com/a/242354509_297288
http://www.360doc.com/content/16/1010/08/36492363_597225745.shtml
https://blog.****.net/tsyccnh/article/details/76673073
On top of Momentum, NAGD also takes an error-derivative term into account, which brings it even closer to the PID algorithm.
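To make the analogy concrete, here is a rough sketch of the correspondence (my own illustration, not code from the original post; pid_step, nag_step, and the grad callback are hypothetical names introduced only for this comparison). The proportional term plays the role of the current gradient, the integral term the role of the decaying velocity accumulator, and the derivative term the role of the lookahead that anticipates where the parameters are heading:

def pid_step(e, e_prev, e_sum, Kp, Ki, Kd):
    # one step of a discrete PID controller:
    # u[k] = Kp*e[k] + Ki*sum(e[0..k]) + Kd*(e[k] - e[k-1])
    e_sum += e
    u = Kp * e + Ki * e_sum + Kd * (e - e_prev)
    return u, e_sum

def nag_step(theta, v, grad, rate=0.01, gamma=0.9, kd_damp=1.0):
    # one NAG-style step; kd_damp scales the derivative-like lookahead
    g = grad(theta - gamma * v * kd_damp)  # gradient at the lookahead point
    v = gamma * v + rate * g               # decaying (integral-like) accumulator
    return theta - v, v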
Momentum update rule:

v_t = γ·v_{t-1} + α·g(θ_{t-1})
θ_t = θ_{t-1} - v_t

Note: the same rule is often written with α moved out of the velocity update (v_t = γ·v_{t-1} + g(θ_{t-1}), θ_t = θ_{t-1} - α·v_t). Both forms of the formula appear online and give the same result, since the velocity is simply rescaled by α.
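A quick numerical check of this claim on a 1-D quadratic loss (my own sketch, not from the original post): rescaling the velocity by the learning rate turns one form into the other, so both produce the same parameter trajectory.

g = lambda t: t          # gradient of the 1-D quadratic loss t**2 / 2
rate, gamma = 0.1, 0.9

t1, v1 = 5.0, 0.0        # form A: v = gamma*v + rate*g(t); t -= v
t2, v2 = 5.0, 0.0        # form B: v = gamma*v + g(t);      t -= rate*v
for _ in range(50):
    v1 = gamma * v1 + rate * g(t1)
    t1 -= v1
    v2 = gamma * v2 + g(t2)
    t2 -= rate * v2
print(t1, t2)            # the two trajectories coincide (up to float rounding)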
NAGD formula one:

v_t = γ·v_{t-1} + α·g(θ_{t-1} - γ·v_{t-1})
θ_t = θ_{t-1} - v_t
NAGD formula two:

v_t = γ·v_{t-1} + g(θ_{t-1} - α·γ·v_{t-1})
θ_t = θ_{t-1} - α·v_t

Note: moving α in front of g() and dropping the α inside g() recovers formula one; the results are identical (again just a rescaling of the velocity by α), as the quick check below confirms.
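The same kind of numerical check, on the same 1-D quadratic loss (again my own sketch, not from the original post):

g = lambda t: t          # gradient of the 1-D quadratic loss t**2 / 2
rate, gamma = 0.1, 0.9

t1, v1 = 5.0, 0.0        # formula one: v = gamma*v + rate*g(t - gamma*v); t -= v
t2, v2 = 5.0, 0.0        # formula two: v = gamma*v + g(t - rate*gamma*v); t -= rate*v
for _ in range(50):
    v1 = gamma * v1 + rate * g(t1 - gamma * v1)
    t1 -= v1
    v2 = gamma * v2 + g(t2 - rate * gamma * v2)
    t2 -= rate * v2
print(t1, t2)            # the two trajectories coincide (up to float rounding)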
NAGD formula three (the per-parameter form actually implemented in the code, for the linear model y = a·x + b):

a_ahead = a - γ·v_a,  b_ahead = b - γ·v_b
v_a = γ·v_a + α·∂L/∂a evaluated at (a_ahead, b_ahead)
v_b = γ·v_b + α·∂L/∂b evaluated at (a_ahead, b_ahead)
a = a - v_a,  b = b - v_b

The excerpt below implements this form, with the lookahead distance additionally scaled by the new hyperparameter kd_damp (plotting code elided; the full listing follows later):
all_loss = []
all_step = []
last_a = a
last_b = b
va = 0
vb = 0
gamma = 0.9
#### Error-derivative term control: with kd_damp = 0 the derivative term has
#### no effect and the update is plain Momentum:
kd_damp = 0.0
####
for step in range(1, 100):
    loss = 0
    all_da = 0
    all_db = 0
    # lookahead point, with the lookahead distance scaled by kd_damp
    a_ahead = a - gamma*va*kd_damp
    b_ahead = b - gamma*vb*kd_damp
    # -- compute the loss and gradients at the lookahead point
    for i in range(0, len(x)):
        y_p = a_ahead*x[i] + b_ahead
        loss = loss + (y[i] - y_p)*(y[i] - y_p)/2
        all_da = all_da + da(y[i], y_p, x[i])
        all_db = all_db + db(y[i], y_p)
    loss = loss/len(x)
    # ………. (plotting code elided; see the full listing below)
    last_a = a
    last_b = b
    ###
    # -- parameter update
    # print('a = %.3f,b = %.3f' % (a,b))
    va = gamma*va + rate*all_da
    vb = gamma*vb + rate*all_db
    a = a - va
    b = b - vb
    # --
The new hyperparameter kd_damp controls the error-derivative (lookahead) term in NAGD. With kd_damp = 1.0 the algorithm is standard NAGD, with the result shown in the figure below:

With kd_damp = 0 the error-derivative term has no effect: the lookahead point coincides with the current parameters, and the update reduces to plain Momentum:

With kd_damp = 2.0 the error-derivative term is strengthened: the optimization converges faster, with smaller oscillations! (A plot-free comparison of all three settings is sketched after the full listing below.)
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# A minimal linear-regression problem; the optimizer is momentum with Nesterov
rate = 0.01  # learning rate

def da(y, y_p, x):
    # partial derivative of the per-sample loss w.r.t. a
    return (y - y_p)*(-x)

def db(y, y_p):
    # partial derivative of the per-sample loss w.r.t. b
    return (y - y_p)*(-1)

def calc_loss(a, b, x, y):
    tmp = y - (a * x + b)
    tmp = tmp ** 2  # square every element of the array
    SSE = sum(tmp) / (2 * len(x))
    return SSE

def draw_hill(x, y):
    a = np.linspace(-20, 20, 100)
    b = np.linspace(-20, 20, 100)
    x = np.array(x)
    y = np.array(y)
    allSSE = np.zeros(shape=(len(a), len(b)))
    for ai in range(0, len(a)):
        for bi in range(0, len(b)):
            a0 = a[ai]
            b0 = b[bi]
            SSE = calc_loss(a=a0, b=b0, x=x, y=y)
            allSSE[ai][bi] = SSE
    a, b = np.meshgrid(a, b)
    return [a, b, allSSE]

# simulated data
x = [30, 35, 37, 59, 70, 76, 88, 100]
y = [1100, 1423, 1377, 1800, 2304, 2588, 3495, 4839]

# normalize the data
x_max = max(x)
x_min = min(x)
y_max = max(y)
y_min = min(y)
for i in range(0, len(x)):
    x[i] = (x[i] - x_min)/(x_max - x_min)
    y[i] = (y[i] - y_min)/(y_max - y_min)

[ha, hb, hallSSE] = draw_hill(x, y)
hallSSE = hallSSE.T  # important: transpose the loss grid, because the matrix
                     # stores elements from top-left to bottom-right while the
                     # plot's origin is at the bottom-left

# initialize a and b
a = 10.0
b = -20.0
fig = plt.figure(1, figsize=(12, 8))
fig.suptitle('learning rate: %.2f method: Nesterov momentum' % (rate), fontsize=15)

# subplot 1: loss surface
ax = fig.add_subplot(2, 2, 1, projection='3d')
ax.set_top_view()
ax.plot_surface(ha, hb, hallSSE, rstride=2, cstride=2, cmap='rainbow')

# subplot 2: contour plot of the loss
plt.subplot(2, 2, 2)
plt.contourf(ha, hb, hallSSE, 15, alpha=0.5, cmap=plt.cm.hot)
C = plt.contour(ha, hb, hallSSE, 15, colors='black')
plt.clabel(C, inline=True)
plt.xlabel('a')
plt.ylabel('b')

plt.ion()  # interactive mode on, so the figure updates inside the loop

all_loss = []
all_step = []
last_a = a
last_b = b
va = 0
vb = 0
gamma = 0.9
#### Error-derivative term control: with kd_damp = 0 the derivative term has
#### no effect and the update is plain Momentum:
kd_damp = 2.0
####
for step in range(1, 100):
    loss = 0
    all_da = 0
    all_db = 0
    # lookahead point, with the lookahead distance scaled by kd_damp
    a_ahead = a - gamma*va*kd_damp
    b_ahead = b - gamma*vb*kd_damp
    # -- compute the loss and gradients at the lookahead point
    for i in range(0, len(x)):
        y_p = a_ahead*x[i] + b_ahead
        loss = loss + (y[i] - y_p)*(y[i] - y_p)/2
        all_da = all_da + da(y[i], y_p, x[i])
        all_db = all_db + db(y[i], y_p)
    loss = loss/len(x)

    ### plotting section
    # subplot 1: current loss point on the surface
    ax.scatter(a, b, loss, color='black')
    # subplot 2: current loss point on the contour plot
    plt.subplot(2, 2, 2)
    plt.scatter(a, b, s=5, color='blue')
    plt.plot([last_a, a], [last_b, b], color='aqua')
    # subplot 3: current regression line
    plt.subplot(2, 2, 3)
    plt.plot(x, y)
    plt.plot(x, y, 'o')
    x_ = np.linspace(0, 1, 2)
    y_draw = a * x_ + b
    plt.plot(x_, y_draw)
    # subplot 4: loss curve over steps
    all_loss.append(loss)
    all_step.append(step)
    plt.subplot(2, 2, 4)
    plt.plot(all_step, all_loss, color='orange')
    plt.xlabel("step")
    plt.ylabel("loss")
    last_a = a
    last_b = b
    ###
    # -- parameter update
    # print('a = %.3f,b = %.3f' % (a,b))
    va = gamma*va + rate*all_da
    vb = gamma*vb + rate*all_db
    a = a - va
    b = b - vb
    # --
    if step % 1 == 0:  # print and redraw every step
        print("step: ", step, " loss: ", loss)
        plt.show()
        plt.pause(0.01)

plt.show()
plt.pause(99999999999)  # keep the window open
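To compare the three kd_damp settings without the plotting machinery, here is a minimal harness (my own sketch, not part of the original post; train is a hypothetical helper). It reuses da, db, calc_loss, and the normalized x, y from the listing above and prints the final parameters and loss for each setting:

import numpy as np  # already imported by the listing above

def train(kd_damp, steps=100, rate=0.01, gamma=0.9):
    # rerun the update loop from the listing for a given kd_damp
    a, b, va, vb = 10.0, -20.0, 0.0, 0.0
    for _ in range(steps):
        a_ahead = a - gamma * va * kd_damp   # lookahead point
        b_ahead = b - gamma * vb * kd_damp
        all_da = sum(da(y[i], a_ahead * x[i] + b_ahead, x[i]) for i in range(len(x)))
        all_db = sum(db(y[i], a_ahead * x[i] + b_ahead) for i in range(len(x)))
        va = gamma * va + rate * all_da      # velocity (integral-like) update
        vb = gamma * vb + rate * all_db
        a, b = a - va, b - vb                # parameter update
    return a, b, calc_loss(a, b, np.array(x), np.array(y))

for kd in (0.0, 1.0, 2.0):  # Momentum, standard NAGD, strengthened derivative term
    a_f, b_f, loss_f = train(kd)
    print('kd_damp=%.1f  a=%.3f  b=%.3f  loss=%.6f' % (kd, a_f, b_f, loss_f))

On this convex problem the printed losses should reflect the post's claim: a larger kd_damp damps the velocity's oscillation and reaches a lower loss within the same number of steps.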