基于随机游走的personalRank算法
基于随机游走的PersonalRank算法是从谷歌的PageRank算法演变而来的,使用得比较少,可以说比较小众。关于PageRank算法本身,在此贴出一篇我认为写得好的博客:《pageRank算法参考》。
1、personalRank算法介绍:
将数据集随机分成训练集和测试集,从训练集中指定的起点开始进行随机游走。游走时根据各条边的权重决定走向每个相邻节点的概率;每到达一个节点,以alpha的概率继续游走,以1-alpha的概率回到起点重新开始游走。经过多轮迭代之后,每个节点都会得到一个被访问到的概率值,再按这个值从高到低进行相关性推荐。
网上python实现的很多,本人的python编程能力也一般,所以用java实现
Rank类(相关性排序)
public class Rank implements Comparable<Rank>{
private int twoColumn;
private double sum_simlatrity=0.0;//设置默认值
public int getTwoColumn() {
return twoColumn;
}
public void setTwoColumn(int movie) {
this.twoColumn = movie;
}
public Double getSum_simlatrity() {
return sum_simlatrity;
}
public void setSum_simlatrity(double sumSimlatrity) {
sum_simlatrity = sumSimlatrity;
}
//重写equals方法
public boolean equals(Object obj) {
return (this.getTwoColumn() == ((Rank) obj).getTwoColumn());
}
public int compareTo(Rank o) {
return this.getSum_simlatrity().compareTo(o.getSum_simlatrity());
}
}
对所有推荐结果进行堆排序(只需要取前k个结果时,堆排序建堆后只做k次取堆顶操作,不必对全部数据完整排序,数据量偏大时比较省时)
/**
 * Partial heap sort: selects the {@code k} largest elements using a max-heap.
 *
 * <p>Elements must be mutually comparable (they are cast to {@code Comparable}).
 */
public class Heapsort {

    /**
     * Reorders {@code list} in place so that its first {@code k} elements are the
     * {@code k} largest ones in descending order. The remaining elements follow in
     * an unspecified order.
     *
     * @param list the list to reorder; its iterator must support {@code set}
     * @param k    how many leading elements must be sorted; values outside
     *             {@code [0, list.size()]} are clamped to that range
     * @param <T>  element type, expected to implement {@code Comparable}
     */
    public <T> void sort(List<T> list, int k) {
        Object[] a = list.toArray();
        sort(a, k);
        // The array holds the k largest values at its tail in ascending order,
        // so writing it back reversed puts the largest value first.
        ListIterator<T> it = list.listIterator();
        for (int j = a.length - 1; j >= 0; j--) {
            it.next();
            @SuppressWarnings("unchecked") // elements came out of this very list
            T element = (T) a[j];
            it.set(element);
        }
    }

    /**
     * Moves the {@code k} largest elements of {@code a} to its tail, in ascending order.
     *
     * @param a the array to reorder in place
     * @param k number of largest elements to extract; clamped to {@code [0, a.length]}
     */
    public void sort(Object[] a, int k) {
        buildMaxHeapify(a);
        heapSort(a, k);
    }

    /** Pops the heap root {@code k} times, parking each popped value behind the shrinking heap. */
    private void heapSort(Object[] a, int k) {
        // Clamp k: without this, k > a.length walks the index down to -1.
        int extractions = Math.max(0, Math.min(k, a.length));
        int stop = a.length - extractions;
        for (int i = a.length - 1; i >= stop; i--) {
            swap(a, 0, i);
            max_heapify(a, 0, i);
        }
    }

    /**
     * Turns the whole array into a max-heap by sifting down every non-leaf node,
     * starting from the last one.
     *
     * @param a the array to heapify in place
     */
    public void buildMaxHeapify(Object[] a) {
        for (int j = (a.length >> 1) - 1; j >= 0; j--) {
            max_heapify(a, j, a.length);
        }
    }

    /**
     * Sifts the element at {@code index} down until the subtree rooted there
     * satisfies the max-heap property.
     *
     * @param a        the array holding the heap
     * @param index    position of the element to sift down
     * @param heapSize number of leading array slots that belong to the heap
     */
    public void max_heapify(Object[] a, int index, int heapSize) {
        int current = index;
        while (true) {
            int left = (current << 1) + 1;
            int right = (current << 1) + 2;
            int largest = current;
            if (left < heapSize && compare(a[left], a[largest]) > 0) {
                largest = left;
            }
            if (right < heapSize && compare(a[right], a[largest]) > 0) {
                largest = right;
            }
            if (largest == current) {
                return;
            }
            // After the swap the child subtree may violate the heap property,
            // so keep sifting from the child position.
            swap(a, current, largest);
            current = largest;
        }
    }

    /** Compares two elements through their natural ordering. */
    @SuppressWarnings("unchecked")
    private static int compare(Object x, Object y) {
        return ((Comparable<Object>) x).compareTo(y);
    }

    /** Exchanges two array slots. */
    private static void swap(Object[] a, int i, int j) {
        Object tmp = a[i];
        a[i] = a[j];
        a[j] = tmp;
    }
}
算法实现主类
/**
 * PersonalRank recommender based on a random walk over a bipartite graph.
 *
 * <p>Edges run from "first column" node ids to "second column" node ids. Ids are used
 * directly as array indexes, so they must be non-negative. Typical usage is the static
 * {@link #go(Map, Long, int, int)} entry point, which splits the data, builds the
 * inverse table and returns the top recommendations for one node.
 */
public class PersonalRank {
    /**
     * Weight given to every edge. The input map carries no weight information,
     * so the graph is treated as unweighted.
     */
    private static final double DEFAULT_EDGE_WEIGHT = 1.0;

    // first-column id -> second-column ids (training / test split)
    HashMap<Integer,Set<Integer>> trainset=new HashMap<Integer,Set<Integer>>();
    HashMap<Integer,Set<Integer>> testset=new HashMap<Integer,Set<Integer>>();
    // second-column id -> first-column ids, built from the training set
    HashMap<Integer,Set<Integer>> inverse_table=new HashMap<Integer,Set<Integer>>();
    // second-column id -> number of training edges pointing at it
    HashMap<Integer,Integer> movie_popular=new HashMap<Integer,Integer>();
    // Edge weights, stored in the iteration order of the matching id set above.
    HashMap<Integer,List<Double>> trainset_weight = new HashMap<>();
    HashMap<Integer,List<Double>> inverse_table_weight = new HashMap<>();
    int i=0; // number of edges loaded so far
    int maxOneId = 0;
    int maxTwoId = 0;
    int trainset_length;
    int testset_length;
    int two_count=0; // number of distinct second-column ids in the inverse table
    List<Rank> recommendedList=null;
    List<Ralated_user> ralatedUsersList=null;
    int k=0; // number of random-walk iterations
    int n=10; // number of recommendations to return
    Random random=new Random(0);
    double alpah=0.6; // probability of continuing the walk instead of restarting at the root

    /**
     * Loads all edges from {@code map} and splits them at random into a test set and
     * a training set. An edge goes to the test set when {@code random.nextInt(8)}
     * equals {@code pivot}.
     *
     * @param pivot value in {@code [0, 8)} selecting which random bucket becomes the test set
     * @param map   each entry pairs a list of start ids with a list of end ids;
     *              position {@code j} of both lists describes one edge
     * @throws IllegalArgumentException if an end-id list is shorter than its start-id list
     * @throws IOException declared for backward compatibility; not thrown by this implementation
     */
    public void generate_dataset(int pivot,Map<List<Long>,List<Long>> map) throws IOException{
        for (Map.Entry<List<Long>,List<Long>> entry : map.entrySet()) {
            List<Long> startIds = entry.getKey();
            List<Long> endIds = entry.getValue();
            if (endIds.size() < startIds.size()) {
                throw new IllegalArgumentException("end id list (" + endIds.size()
                        + ") is shorter than start id list (" + startIds.size() + ")");
            }
            for (int idx = 0; idx < startIds.size(); idx++) {
                int startInt = toInt(startIds.get(idx));
                int endInt = toInt(endIds.get(idx));
                maxOneId = Math.max(maxOneId, startInt);
                maxTwoId = Math.max(maxTwoId, endInt);
                if (random.nextInt(8) == pivot) {
                    testset.computeIfAbsent(startInt, key -> new HashSet<Integer>()).add(endInt);
                    testset_length++;
                } else {
                    addTrainingEdge(startInt, endInt, DEFAULT_EDGE_WEIGHT);
                    trainset_length++;
                }
                this.i++;
                if (this.i % 100000 == 0) {
                    System.out.println("已装载"+ this.i +"文件");
                }
            }
        }
        System.out.println("测试集和训练集分割完成,测试集长度:"+testset_length+",训练集长度:"+trainset_length);
    }

    /** Records one training edge, keeping the weight list aligned with the id set. */
    private void addTrainingEdge(int start, int end, double weight) {
        // Insertion-ordered set: weights are looked up by iteration position.
        Set<Integer> targets =
                trainset.computeIfAbsent(start, key -> new java.util.LinkedHashSet<Integer>());
        List<Double> weights =
                trainset_weight.computeIfAbsent(start, key -> new ArrayList<Double>());
        // Only store a weight for a genuinely new edge, otherwise positions drift apart.
        if (targets.add(end)) {
            weights.add(weight);
        }
    }

    /** Converts an id to int; throws NumberFormatException when it does not fit. */
    private static int toInt(Long value) {
        return Integer.parseInt(value.toString());
    }

    /** Returns the weight at {@code pos}, or the default weight when none is stored. */
    private static double weightAt(List<Double> weights, int pos) {
        if (weights == null || pos >= weights.size()) {
            return DEFAULT_EDGE_WEIGHT;
        }
        return weights.get(pos);
    }

    /**
     * Builds the inverse table (second-column id to first-column ids) with its weights
     * from the training set, and counts how many training edges point at each
     * second-column id.
     */
    public void calc_user_sim(){
        for (Map.Entry<Integer,Set<Integer>> entry : trainset.entrySet()) {
            Integer source = entry.getKey();
            List<Double> weights = trainset_weight.get(source);
            int pos = 0;
            for (Integer target : entry.getValue()) {
                double weight = weightAt(weights, pos);
                inverse_table
                        .computeIfAbsent(target, key -> new java.util.LinkedHashSet<Integer>())
                        .add(source);
                inverse_table_weight
                        .computeIfAbsent(target, key -> new ArrayList<Double>())
                        .add(weight);
                // count item popularity at the same time
                movie_popular.merge(target, 1, Integer::sum);
                pos++;
            }
        }
        two_count=inverse_table.size();
    }

    /**
     * Runs the PersonalRank iteration from {@code root} and stores the candidates in
     * {@code recommendedList}, with the best {@code n} first in descending score order.
     * Nodes already linked to {@code root} in the training set and nodes with a zero
     * score are excluded. If {@code root} is outside the loaded id range the list is
     * left empty.
     *
     * @param root     first-column id the walk starts from
     * @param max_step number of iterations to run
     */
    public void personalRank(int root,int max_step){
        recommendedList=new ArrayList<Rank>();
        if (root < 0 || root > maxOneId) {
            return; // unknown node: nothing can be recommended
        }
        double[] rank1 = new double[maxOneId + 1];
        double[] rank2 = new double[maxTwoId + 1];
        rank1[root]=1; // the walk starts at the root, so its initial score is 1
        for (int step = 0; step < max_step; step++) {
            // Fresh zero-filled accumulators for this iteration.
            double[] next1 = new double[rank1.length];
            double[] next2 = new double[rank2.length];
            // Propagate scores from first-column nodes to second-column nodes.
            for (Map.Entry<Integer,Set<Integer>> entry : trainset.entrySet()) {
                int source = entry.getKey();
                Set<Integer> targets = entry.getValue();
                List<Double> weights = trainset_weight.get(source);
                int pos = 0;
                for (int target : targets) {
                    next2[target] += alpah * rank1[source] / (weightAt(weights, pos) * targets.size());
                    pos++;
                }
            }
            // Propagate scores back from second-column nodes to first-column nodes.
            for (Map.Entry<Integer,Set<Integer>> entry : inverse_table.entrySet()) {
                int source = entry.getKey();
                Set<Integer> targets = entry.getValue();
                List<Double> weights = inverse_table_weight.get(source);
                int pos = 0;
                for (int target : targets) {
                    next1[target] += alpah * rank2[source] / (weightAt(weights, pos) * targets.size());
                    pos++;
                }
            }
            next1[root] += 1 - alpah; // restart share goes back to the root
            rank1 = next1;
            rank2 = next2;
        }
        // The root may have no training edges at all (e.g. all of them fell in the test set).
        Set<Integer> watched = trainset.get(root);
        for (int id = 0; id < rank2.length; id++) {
            if (rank2[id] == 0.0 || (watched != null && watched.contains(id))) {
                continue;
            }
            Rank r = new Rank();
            r.setTwoColumn(id);
            r.setSum_simlatrity(rank2[id]);
            recommendedList.add(r);
        }
        // Use a local limit so the configured n is not overwritten by a short result list.
        int topN = Math.max(0, Math.min(n, recommendedList.size()));
        new Heapsort().sort(recommendedList, topN);
    }

    /**
     * Computes the recommendations for {@code user} using {@code k} iterations.
     *
     * @param user first-column id to recommend for
     * @return at most {@code n} recommendations, iterating from highest to lowest score;
     *         empty when nothing can be recommended
     */
    public Set<Rank> evaluate(int user){
        personalRank(user, k);
        // Insertion-ordered set keeps the ranking order of the sorted list.
        Set<Rank> all_rec = new java.util.LinkedHashSet<Rank>();
        int limit = Math.min(n, recommendedList.size());
        for (int idx = 0; idx < limit; idx++) {
            all_rec.add(recommendedList.get(idx));
        }
        return all_rec;
    }

    /**
     * Convenience entry point: loads the data, builds the model and returns the
     * recommendations for one node.
     *
     * @param map          the data set (lists of start ids mapped to lists of end ids)
     * @param targetNodeId the node to recommend for
     * @param k            number of iterations
     * @param n            number of top recommendations to return
     * @return the recommendations, or an empty set if loading the data failed
     */
    public static Set<Rank> go(Map<List<Long>,List<Long>> map,Long targetNodeId,int k,int n){
        Set<Rank> all_rec = new HashSet<Rank>();
        try {
            PersonalRank ss=new PersonalRank();
            // split the data into test and training sets
            ss.generate_dataset(3,map);
            // count popularity and build the inverse table
            ss.calc_user_sim();
            ss.k=k;
            ss.n = n;
            all_rec = ss.evaluate(toInt(targetNodeId));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return all_rec;
    }
}
PersonalRank算法的时间复杂度较高(处理约2M的MovieLens数据集大约需要50s),不适合大数据量的数据集使用,更不适合实时推荐项目使用。