自组织神经网络SOM原理——结合例子MATLAB实现
本文主要内容为SOM神经网络原理的介绍,并结合实例给出相应的MATLAB代码实现,方便初学者接触学习,本人才疏学浅,如有纰漏,还望各路大神积极指点。
一、SOM神经网络介绍
自组织映射神经网络, 即Self Organizing Maps (SOM), 可以对数据进行无监督学习聚类。它的思想很简单,本质上是一种只有输入层和隐藏层的神经网络。隐藏层中的一个节点代表一个需要聚成的类。训练时采用“竞争学习”的方式,每个输入的样例在隐藏层中找到一个和它最匹配的节点,称为它的激活节点,也叫“winning neuron”。 紧接着用随机梯度下降法更新激活节点的参数。同时,和激活节点临近的点也根据它们距离激活节点的远近而适当地更新参数。
所以,SOM的一个特点是,隐藏层的节点是有拓扑关系的。这个拓扑关系需要我们确定,如果想要一维的模型,那么隐藏节点依次连成一条线;如果想要二维的拓扑关系,那么就形成一个平面,如下图所示(也叫Kohonen Network):
既然隐藏层是有拓扑关系的,所以我们也可以说,SOM可以把任意维度的输入离散化到一维或者二维(更高维度的不常见)的离散空间上。 Computation layer里面的节点与Input layer的节点是全连接的。
拓扑关系确定后,开始计算过程,大体分成几个部分:
1) 初始化:每个节点随机初始化自己的参数。每个节点的参数个数与Input的维度相同。
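2)对于每一个输入数据,找到与它最相配的节点。假设输入是D维的, 即 X={x_i, i=1,...,D},那么判别函数可以为欧几里得距离(的平方):$d_j(x)=\sum_{i=1}^{D}(x_i-w_{ji})^2$,其中 $w_{ji}$ 为节点 j 的第 i 个权值;使 $d_j(x)$ 最小的节点即为激活节点 $I(x)$。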
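3) 找到激活节点I(x)之后,我们也希望更新和它临近的节点。令S_ij表示节点i和j之间的距离,对于I(x)临近的节点,分配给它们一个更新权重:$T_{j,I(x)}=\exp\left(-\dfrac{S_{j,I(x)}^2}{2\sigma^2}\right)$(后文代码中取 $\delta^2=2\sigma^2$,即 exp(-distant^2/delta^2))。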
简单地说,临近的节点根据距离的远近,更新程度要打折扣。
4)接着就是更新节点的参数了。按照梯度下降法更新:$\Delta w_{ji}=\eta(t)\cdot T_{j,I(x)}(t)\cdot(x_i-w_{ji})$,其中 $\eta(t)$ 为学习率。
迭代,直到收敛。
二、问题描述
用26个英文字母作为SOM输入样本。每个字符对应一个5维向量,各字符与向量的关系如表4-2所示。由表4-2可以看出,代表A、B、C、D、E的各向量中有4个分量相同,即 x_2=x_3=x_4=x_5=0,因此,A、B、C、D、E应归为一类;代表F、G、H、I、J的向量中有3个分量相同,同理也应归为一类;依此类推。这样就可以由表4-2中输入向量的相似关系,将对应的字符标在图4-8所示的树形结构图中。用SOM网络对其进行聚类分析。
三、MATLAB代码实现
SOM_mian.m
- %%% Self-organizing map (SOM) neural-network exercise
- %%% Author: xd.wp
- %%% Date: 2016.10.02 19:16
- %% Program notes:
- %%% 1. The output layer is a two-dimensional plane.
- %%% 2. Neighbourhood weighting and weight adjustment use exp(-distant^2/delta^2).
- %%% 3. Samples are 5-dimensional; the output layer has 70 nodes.
- %%% 4. Input samples are normalised to unit vectors.
- clear all;
- clc;
- %% Network and parameter initialisation
- % Load the (already normalised) samples and their labels.
- [train_data,train_label]=SOM_data_process();
- data_num=size(train_data,2);
- % Random initial weights, one column per output node.
- weight_temp=rand(5,70)/1000;
- node_num=size(weight_temp,2);
- % Scale every column so that its largest component becomes 1.
- weight=bsxfun(@rdivide,weight_temp,max(weight_temp,[],1));
- % Width parameter of the neighbourhood function.
- delta=2;
- % Learning step size.
- alpha=0.6;
- %% Kohonen learning
- for t=4:-1:1
-     % All nodes may compete at the start of each pass.
-     index_active=ones(1,node_num);
-     % Neighbourhood radius used during this pass: 3, 2, 1, then 0.
-     radius=t-1;
-     for n=1:data_num
-         sample=train_data(:,n);
-         % Competition: the closest still-active node wins.
-         j_min=SOM_compare(weight,sample,node_num,index_active);
-         % Deactivate the winner so that, within this pass, no other
-         % sample can be mapped onto the same node.
-         index_active(1,j_min)=0;
-         % Remember the winner for the plot below.
-         index_plot(1,n)=j_min;
-         [x,y]=line_to_array(j_min);
-         fprintf('坐标[%d,%d]处为字符%s \n',x,y,train_label(1,n));
-         % Learning: adjust the weights around the winner.
-         if radius==3
-             weight=SOM_neighb3(weight,sample,j_min,delta,alpha);
-         elseif radius==2
-             weight=SOM_neighb2(weight,sample,j_min,delta,alpha);
-         elseif radius==1
-             weight=SOM_neighb1(weight,sample,j_min,delta,alpha);
-         else
-             weight=SOM_neighb0(weight,sample,j_min,alpha);
-         end
-     end
- end
- %% Plot where each sample landed on the output grid
- figure(1);
- for n=1:data_num
-     [x,y]=line_to_array(index_plot(1,n));
-     axis([0,12,0,12]);
-     % Marker at the node position, label slightly offset from it.
-     text(x,y,'*');
-     text(x+0.2,y+0.2,train_label(1,n));
-     hold on;
- end
- function [train_data,train_label]=SOM_data_process()
- % SOM_DATA_PROCESS  Build the character data set used by the SOM demo.
- %   train_data  : 5-by-26 matrix, one sample per column, each column
- %                 scaled to unit Euclidean length.
- %   train_label : 1-by-32 char row 'A'..'Z' followed by '1'..'6'.
- %                 Only the first 26 labels have a matching data column.
-     % One row per character, in label order.
-     raw=[1 0 0 0 0; 2 0 0 0 0; 3 0 0 0 0; 4 0 0 0 0; 5 0 0 0 0;
-          3 1 0 0 0; 3 2 0 0 0; 3 3 0 0 0; 3 4 0 0 0; 3 5 0 0 0;
-          3 3 1 0 0; 3 3 2 0 0; 3 3 3 0 0; 3 3 4 0 0; 3 3 5 0 0;
-          3 3 3 1 0; 3 3 3 2 0; 3 3 3 3 0; 3 3 3 4 0; 3 3 3 5 0;
-          3 3 3 3 1; 3 3 3 3 2; 3 3 3 3 3; 3 3 3 3 4; 3 3 3 3 5;
-          3 3 3 3 6];
-     train_label='ABCDEFGHIJKLMNOPQRSTUVWXYZ123456';
-     % Samples become columns.
-     train_data=raw';
-     % Divide every column by its Euclidean norm.
-     col_norm=sqrt(sum(train_data.*train_data,1));
-     train_data=bsxfun(@rdivide,train_data,col_norm);
- end
SOM_compare.m
- function [j_min]=SOM_compare(weight,train_data_active,node_num,index_active)
- % SOM_COMPARE  Pick the winning node for one sample.
- %   weight            : matrix with one weight column per node
- %   train_data_active : column vector, the current sample
- %   node_num          : number of nodes (columns of weight) that compete
- %   index_active      : 1-by-node_num flags; a node whose flag is 0 is
- %                       excluded from the competition
- %   j_min             : index of the active node with the smallest
- %                       squared Euclidean distance to the sample
- %                       (the lowest index wins a tie)
- %
- %   Raises SOM_compare:noActiveNode when every node is inactive. The
- %   previous implementation looped forever in that situation and relied
- %   on a magic distance of 10000000 to mask inactive nodes.
-     candidates=find(index_active(1,1:node_num)~=0);
-     if isempty(candidates)
-         error('SOM_compare:noActiveNode', ...
-               'No active node is left to compete for this sample.');
-     end
-     % Squared distance from the sample to every candidate node.
-     offset=bsxfun(@minus,weight(:,candidates),train_data_active);
-     distant=sum(offset.^2,1);
-     [~,best]=min(distant);
-     j_min=candidates(best);
- end
SOM_neighb3.m
- function [weight]=SOM_neighb3(weight,train_data_active,j_min,delta,alpha)
- % SOM_NEIGHB3  Weight update with a neighbourhood radius of 3.
- %   weight            : one weight column per node; node k sits at grid
- %                       row m, column n with k=(n-1)*7+m (7-by-10 grid)
- %   train_data_active : column vector, the current sample
- %   j_min             : linear index of the winning node
- %   delta             : width of the neighbourhood function
- %   alpha             : learning step size
- %
- %   Every node inside the square window of half-width 3 around the
- %   winner, clipped to the 7-by-10 grid, is pulled towards the sample by
- %   alpha*exp(-d^2/delta^2), where d is its grid distance to the winner.
-     [x,y]=line_to_array(j_min);
-     % Window limits, clipped to the grid borders.
-     row_lo=max(1,x-3);
-     row_hi=min(7,x+3);
-     col_lo=max(1,y-3);
-     col_hi=min(10,y+3);
-     for m=row_lo:row_hi
-         for n=col_lo:col_hi
-             k=(n-1)*7+m;
-             distant=sqrt((x-m)^2+(y-n)^2);
-             gain=alpha*exp(-distant^2/delta^2);
-             weight(:,k)=weight(:,k)-gain*(weight(:,k)-train_data_active);
-         end
-     end
- end
SOM_neighb2.m
- function [weight]=SOM_neighb2(weight,train_data_active,j_min,delta,alpha)
- % SOM_NEIGHB2  Weight update with a neighbourhood radius of 2.
- %   weight            : one weight column per node; node k sits at grid
- %                       row m, column n with k=(n-1)*7+m (7-by-10 grid)
- %   train_data_active : column vector, the current sample
- %   j_min             : linear index of the winning node
- %   delta             : width of the neighbourhood function
- %   alpha             : learning step size
- %
- %   Every node inside the square window of half-width 2 around the
- %   winner, clipped to the 7-by-10 grid, is pulled towards the sample by
- %   alpha*exp(-d^2/delta^2), where d is its grid distance to the winner.
-     [x,y]=line_to_array(j_min);
-     % Window limits, clipped to the grid borders.
-     row_lo=max(1,x-2);
-     row_hi=min(7,x+2);
-     col_lo=max(1,y-2);
-     col_hi=min(10,y+2);
-     for m=row_lo:row_hi
-         for n=col_lo:col_hi
-             k=(n-1)*7+m;
-             distant=sqrt((x-m)^2+(y-n)^2);
-             gain=alpha*exp(-distant^2/delta^2);
-             weight(:,k)=weight(:,k)-gain*(weight(:,k)-train_data_active);
-         end
-     end
- end
SOM_neighb1.m
- function [weight]=SOM_neighb1(weight,train_data_active,j_min,delta,alpha)
- % SOM_NEIGHB1  Weight update with a neighbourhood radius of 1.
- %   weight            : one weight column per node; node k sits at grid
- %                       row m, column n with k=(n-1)*7+m (7-by-10 grid)
- %   train_data_active : column vector, the current sample
- %   j_min             : linear index of the winning node
- %   delta             : width of the neighbourhood function
- %   alpha             : learning step size
- %
- %   Every node inside the square window of half-width 1 around the
- %   winner, clipped to the 7-by-10 grid, is pulled towards the sample by
- %   alpha*exp(-d^2/delta^2), where d is its grid distance to the winner.
- %
- %   The earlier branch-per-edge version used y+3 on the left edge and
- %   x-3 on the bottom edge (leftovers from the radius-3 routine), which
- %   widened the window there; the clipped window below has half-width 1
- %   on every side.
-     [x,y]=line_to_array(j_min);
-     % Window limits, clipped to the grid borders.
-     row_lo=max(1,x-1);
-     row_hi=min(7,x+1);
-     col_lo=max(1,y-1);
-     col_hi=min(10,y+1);
-     for m=row_lo:row_hi
-         for n=col_lo:col_hi
-             k=(n-1)*7+m;
-             distant=sqrt((x-m)^2+(y-n)^2);
-             gain=alpha*exp(-distant^2/delta^2);
-             weight(:,k)=weight(:,k)-gain*(weight(:,k)-train_data_active);
-         end
-     end
- end
SOM_neighb0.m
- function [weight]=SOM_neighb0(weight,train_data_active,j_min,alpha)
- % SOM_NEIGHB0  Weight update with a neighbourhood radius of 0.
- %   weight            : one weight column per node
- %   train_data_active : column vector, the current sample
- %   j_min             : index of the winning node
- %   alpha             : learning step size
- %
- %   Only the winner is adjusted. It moves towards the sample by the
- %   fraction alpha of the remaining difference. The previous "+" sign
- %   moved the winner away from the sample instead.
-     weight(:,j_min)=weight(:,j_min)-alpha*(weight(:,j_min)-train_data_active);
- end
line_to_array.m
- function [x,y]=line_to_array(j_min)
- % LINE_TO_ARRAY  Convert a linear node index into 7-by-10 grid coordinates.
- %   The grid is numbered column by column:
- %       1   8  ...
- %       ...
- %       7  14  ...
- %   j_min : linear index of the node
- %   x     : row of the node, 1..7 for j_min in 1..70
- %   y     : column of the node, 1..10 for j_min in 1..70
-     y=ceil(j_min/7);
-     % mod on the zero-based index keeps the last row at 7; a plain
-     % rem(j_min,7) returns 0 for every multiple of 7.
-     x=mod(j_min-1,7)+1;
- end
四、结果显示
不同初始条件的结果图