tensorflow SVM 线性可分数据分类

引言

对于SVM,具体的可以参考其他博客。我觉得SVM里面的数学知识不好懂,特别是拉格朗日乘子法和后续的 FTT 条件。一个简单的从整体上先把握的方法就是:不要管怎么来的,知道后面是一个二次优化就行了。
此处还是用 iris 数据集的萼片长度和花瓣宽度来对鸢尾花分类。
所用的损失函数是
tensorflow SVM 线性可分数据分类
其中 n 是每次训练的数据量,就是下面代码中的 batch_size ,A、b是要优化的变量,A是系数,b是偏差。yi 是理论输出(-1 或 1)。a(阿尔法,打不出)是权重,自己设,至于怎么设就根据经验了。

结果展示

tensorflow SVM 线性可分数据分类

代码

# 导入库
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
from sklearn import datasets

sess = tf.Session()

# 设置随机种子,代码结束后有讨论
np.random.seed(7)
tf.set_random_seed(8)

iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])

# 随机取90%的数据为训练集,剩下的为测试集
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.9), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test  = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test  = y_vals[test_indices]

batch_size = 100
trian_times = 500
learning_rate = 0.01

# placeholder 就是装训练时放数据的容器
x_data = tf.placeholder(dtype=tf.float32, shape=[None, 2])
y_target = tf.placeholder(dtype=tf.float32, shape=[None, 1])

# 变量Variable 是要优化的变量,要是不能变肯定不能优化了
A = tf.Variable(tf.random_normal(shape=[2,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

module_output = tf.subtract(tf.matmul(x_data, A), b)

# 损失函数,此处用的是:看下文的损失函数
l2_norm = tf.reduce_sum(tf.square(A))
alpha = tf.constant([0.01])
classification_term = tf.reduce_mean(tf.maximum(0., tf.subtract(1., tf.multiply(module_output, y_target))))
loss = tf.add(classification_term, tf.multiply(alpha, l2_norm))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# 计算精度
prediction = tf.sign(module_output)
accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, y_target), tf.float32))

init = tf.global_variables_initializer()
sess.run(init)

loss_vec = []
train_accuracy = []
test_accuracy = []

# 训练在 for 循环里面
for i in range(trian_times):
	index_rand = np.random.choice(len(x_vals_test), size=batch_size)
	x_rand = x_vals_train[index_rand]
	y_rand = np.transpose([y_vals_train[index_rand]])
	sess.run(optimizer, feed_dict={x_data:x_rand, y_target:y_rand})
	loss_vec.append(sess.run(loss, feed_dict={x_data:x_rand, y_target:y_rand}))
	train_accuracy.append(sess.run(accuracy, feed_dict={x_data:x_vals_train, y_target:np.transpose([y_vals_train])}))
	test_accuracy.append(sess.run(accuracy, feed_dict={x_data:x_vals_test, y_target:np.transpose([y_vals_test])}))

[[a1], [a2]] = sess.run(A)
[[bb]] = sess.run(b)

# 准备画分割线
slope = -a2/a1
intercept = bb/a1
x_line = [x[1] for x in x_vals]
y_line = [slope*x+intercept for x in x_line]

# 准备画数据点
setosa_x = [d[1] for i,d in enumerate(x_vals) if y_vals[i]==1]
setosa_y = [d[0] for i,d in enumerate(x_vals) if y_vals[i]==1]
not_setosa_x = [d[1] for i,d in enumerate(x_vals) if y_vals[i]==-1]
not_setosa_y = [d[0] for i,d in enumerate(x_vals) if y_vals[i]==-1]

# 画数据点
plt.subplot(221)
plt.plot(setosa_x, setosa_y, 'ro', label='Is setosa')
plt.plot(not_setosa_x, not_setosa_y, 'g*', label='Non-setosa')
plt.plot(x_line, y_line, 'b-', label='Linear Seperator')
plt.ylim([2, 10])
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.legend(loc='lower right')

# 画精度变化曲线
plt.subplot(222)
plt.plot(train_accuracy, 'g-', label='Training Accuracy')
plt.plot(test_accuracy, 'r--', label='Test Accuracy')
plt.title('Train and Test Set Accuracies')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')

# 画 Loss 曲线
plt.subplot(223)
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')

plt.show()

讨论

上面有一个随机树的设定,制定了产生随机数的种子(seed),如果你运行我的代码,虽然中间有产生随机数的过程,但产生的所有随机数都我运行时产生的随机数一样,最终得到的分割线的斜率等也和我的一模一样。
但是,让人迷惑的是,换用一些其他的随机数种子,即把上面的第 10、11行写成

np.random.seed(14)
tf.set_random_seed(222)

这是产生的图如下
tensorflow SVM 线性可分数据分类
一看就知道分类效果很差。
换句话说,分类效果和随机数有关,要是不初始化随机数,可成产生很糟的结果。在刚开始我没有写第10、11行时,出现过不少完全没有分为两类的情况。我也不清楚什么原因。欢迎留言讨论。

参考书籍

Nick McClure. TensorFlow机器学习攻略(影印版)[M]. 东南大学出版社(南京).2017.10