Python实现K-Means++聚类算法
声明:代码的运行环境为Python3。Python3与Python2在一些细节上会有所不同,希望广大读者注意。本博客以代码为主,代码中会有详细的注释。相关文章将会发布在我的个人博客专栏《Python从入门到深度学习》,欢迎大家关注~
之前我写过一篇文章叫《Python实现K-Means聚类算法》,这篇文章主要是在之前的基础上介绍K-Means算法的改进版——K-Means++聚类算法。
通过之前K-Means算法的实现过程,我们可以发现K-Means算法的聚类中心 需要提前指定好的,这一点对于数据是有局限性的;再者,如果初始聚类中心选择不好,对聚类结果将会影响很大。为了解决K-Means算法带来的问题,K-Means++算法被提出了。K-Means++算法可以在聚类中心选择的过程中选择较优的聚类中心。
一、K-Means++算法原理及步骤
K-Means++算法在初始化聚类中心时的基本原则是使聚类中心之间的相互距离尽可能的远,其初始过程如下:
1、在数据集中随机选择一个样本作为第一个初始化聚类中心;
2、计算样本中每一个样本点与已经初始化的聚类中心的距离,并选择其中最短的距离;
3、以概率选择距离最大的点作为新的聚类中心;
4、重复2、3步直至选出k个聚类中心;
5、对k个聚类中心使用K-Means算法计算最终的聚类结果。
通过上述的步骤可知,K-Means++算法和K-Means算法最本质的区别在与聚类中心的初始化,K-Means++算法聚类中心初始化的构建过程如下:
def get_cent(points, k):
'''
kmeans++的初始化聚类中心的方法
:param points: 样本
:param k: 聚类中心的个数
:return: 初始化后的聚类中心
'''
m, n = np.shape(points)
cluster_centers = np.mat(np.zeros((k, n)))
# 1、随机选择一个样本点作为第一个聚类中心
index = np.random.randint(0, m)
cluster_centers[0, ] = np.copy(points[index, ]) # 复制函数,修改cluster_centers,不会影响points
# 2、初始化一个距离序列
d = [0.0 for _ in range(m)]
for i in range(1, k):
sum_all = 0
for j in range(m):
# 3、对每一个样本找到最近的聚类中心点
d[j] = nearest(points[j, ], cluster_centers[0:i, ])
# 4、将所有的最短距离相加
sum_all += d[j]
# 5、取得sum_all之间的随机值
sum_all *= random()
# 6、获得距离最远的样本点作为聚类中心点
for j, di in enumerate(d): # enumerate()函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同事列出数据和数据下标一般用在for循环中
sum_all -= di
if sum_all > 0:
continue
cluster_centers[i] = np.copy(points[j, ])
break
return cluster_centers
在上述代码中,函数nearest()用于计算最短距离,具体实现过程如下所示:
def nearest(point, cluster_centers):
'''
计算point和cluster_centers之间的最小距离
:param point: 当前的样本点
:param cluster_centers: 当前已经初始化的聚类中心
:return: 返回point与当前聚类中心的最短距离
'''
min_dist = FLOAT_MAX
m = np.shape(cluster_centers)[0] # 当前已经初始化聚类中心的个数
for i in range(m):
# 计算point与每个聚类中心之间的距离
d = distance(point, cluster_centers[i, ])
# 选择最短距离
if min_dist > d:
min_dist = d
return min_dist
二、K-Means++算法举例
数据集以及加载数据集、保存聚类结果等相关方法还是使用K-Means算法中的方法,可以直接import使用。
1、调用K-Means算法
调用K-Means算法:
if __name__ == "__main__":
k = 4 # 聚类中心的个数
file_path = "tfidf.txt"
subCenter, centroids = kmeans(load_data(file_path), k, get_cent(load_data(file_path), k))
save_result("result/kmeans_sub", subCenter)
save_result("result/kmeans_center", centroids)
2、结果展示
得到的聚类结果如图所示: