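Simple Mahout classification example

Problem description:

I want to train Mahout to do classification. In my case the text comes from a database, and I really don't want to write it out to files just to train Mahout. I looked through the MIA (Mahout in Action) source code and adapted the code below for a very basic training task. The usual problem with the Mahout examples is that they either show how to run Mahout from the command prompt on the 20 newsgroups data set, or the code has a lot of dependencies on Hadoop, ZooKeeper, and so on. I would appreciate it if someone could look at my code, or point me to a very simple tutorial that shows how to train a model and then use it.

As of now, the code below never gets past if (best != null), because learningAlgorithm.getBest() always returns null!

Sorry for posting the entire code, but I didn't see any other option.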

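// Imports assumed from the class names used below; the Mahout package
// names may differ between Mahout versions.
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.ListMultimap;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.CrossFoldLearner;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.ModelSerializer;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;

public class Classifier { 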

    private static final int FEATURES = 10000; 
    private static final TextValueEncoder encoder = new TextValueEncoder("body"); 
    private static final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); 
    private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"}; 

    /** 
    * @param args the command line arguments 
    */ 
    public static void main(String[] args) throws Exception { 
     int leakType = 0; 
     // TODO code application logic here 
     AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1()); 
     Dictionary newsGroups = new Dictionary(); 
     //ModelDissector md = new ModelDissector(); 
     ListMultimap<String, String> noteBySection = LinkedListMultimap.create(); 
     noteBySection.put("good", "I love this product, the screen is a pleasure to work with and is a great choice for any business"); 
     noteBySection.put("good", "What a product!! Really amazing clarity and works pretty well"); 
     noteBySection.put("good", "This product has good battery life and is a little bit heavy but I like it"); 

     noteBySection.put("bad", "I am really bored with the same UI, this is their 5th version(or fourth or sixth, who knows) and it looks just like the first one"); 
     noteBySection.put("bad", "The phone is bulky and useless"); 
     noteBySection.put("bad", "I wish i had never bought this laptop. It died in the first year and now i am not able to return it"); 


     encoder.setProbes(2); 
     double step = 0; 
     int[] bumps = {1, 2, 5}; 
     double averageCorrect = 0; 
     double averageLL = 0; 
     int k = 0; 
     //------------------------------------- 
     //notes.keySet() 
     for (String key : noteBySection.keySet()) { 
      System.out.println(key); 
      List<String> notes = noteBySection.get(key); 
      for (Iterator<String> it = notes.iterator(); it.hasNext();) { 
       String note = it.next(); 


       int actual = newsGroups.intern(key); 
       Vector v = encodeFeatureVector(note); 
       learningAlgorithm.train(actual, v); 

       k++; 
       int bump = bumps[(int) Math.floor(step) % bumps.length]; 
       int scale = (int) Math.pow(10, Math.floor(step/bumps.length)); 
       State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); 
       double maxBeta; 
       double nonZeros; 
       double positive; 
       double norm; 

       double lambda = 0; 
       double mu = 0; 
       if (best != null) { 
        CrossFoldLearner state = best.getPayload().getLearner(); 
        averageCorrect = state.percentCorrect(); 
        averageLL = state.logLikelihood(); 

        OnlineLogisticRegression model = state.getModels().get(0); 
        // finish off pending regularization 
        model.close(); 

        Matrix beta = model.getBeta(); 
        maxBeta = beta.aggregate(Functions.MAX, Functions.ABS); 
        nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() { 

         @Override 
         public double apply(double v) { 
          return Math.abs(v) > 1.0e-6 ? 1 : 0; 
         } 
        }); 
        positive = beta.aggregate(Functions.PLUS, new DoubleFunction() { 

         @Override 
         public double apply(double v) { 
          return v > 0 ? 1 : 0; 
         } 
        }); 
        norm = beta.aggregate(Functions.PLUS, Functions.ABS); 

        lambda = learningAlgorithm.getBest().getMappedParams()[0]; 
        mu = learningAlgorithm.getBest().getMappedParams()[1]; 
       } else { 
        maxBeta = 0; 
        nonZeros = 0; 
        positive = 0; 
        norm = 0; 
       } 
       System.out.println(k % (bump * scale)); 
       if (k % (bump * scale) == 0) { 

        if (learningAlgorithm.getBest() != null) { 
         System.out.println("----------------------------"); 
         ModelSerializer.writeBinary("c:/tmp/news-group-" + k + ".model", 
           learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); 
        } 

        step += 0.25; 
        System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu); 
        System.out.printf("%d\t%.3f\t%.2f\t%s\n", 
          k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]); 
       } 
      } 

     } 
     learningAlgorithm.close(); 
    } 

    private static Vector encodeFeatureVector(String text) { 
     encoder.addText(text.toLowerCase()); 
     //System.out.println(encoder.asString(text)); 
     Vector v = new RandomAccessSparseVector(FEATURES); 
     bias.addToVector((byte[]) null, 1, v); 
     encoder.flush(1, v); 
     return v; 
    } 
} 

Could you update the original code example with the suggested fix? That would help get the example working. Thanks. – Eugen 2013-06-21 11:19:20

You need to add the words to your feature vector correctly. It looks like the following code:

 bias.addToVector((byte[]) null, 1, v); 

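does not do what you expect. It simply adds a null byte array to the feature vector with a weight of 1.

You are calling a wrapper around the WordValueEncoder.addToVector(byte[] originalForm, double w, Vector data) method.

Make sure you loop over the words in the values of your notes map and add them to the feature vector accordingly.

I also strongly recommend posting your question to the very nice people on the Mahout mailing list: https://mahout.apache.org/general/mailing-lists,-irc-and-archives.html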
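To make the idea of adding each word to the feature vector concrete, here is a minimal sketch in plain Java with no Mahout dependency. It is not Mahout's encoder: the class HashedBagOfWords, its encode method, and the use of String.hashCode are all illustrative assumptions, and Mahout's own encoders use their own hashing scheme. It only shows the general pattern of looping over the words and giving each one a weight in a fixed-size vector.

```java
import java.util.Locale;

// Hypothetical stand-in for a hashed text encoder; not part of Mahout.
class HashedBagOfWords {

    // Splits the text into words and adds a weight of 1 for each word
    // at `probes` hashed positions of a vector with `features` slots.
    static double[] encode(String text, int features, int probes) {
        double[] v = new double[features];
        for (String word : text.toLowerCase(Locale.ROOT).split("\\W+")) {
            if (word.isEmpty()) {
                continue; // skip the empty token produced by leading punctuation
            }
            for (int p = 0; p < probes; p++) {
                // Salt the word with the probe index so each probe picks its own slot.
                int slot = Math.floorMod((word + "#" + p).hashCode(), features);
                v[slot] += 1.0;
            }
        }
        return v;
    }

    public static void main(String[] args) {
        double[] v = encode("The phone is bulky and useless", 10000, 2);
        double total = 0;
        for (double x : v) {
            total += x;
        }
        // 6 words, 2 probes each
        System.out.println("total weight = " + total);
    }
}
```

The vector size and probe count mirror FEATURES = 10000 and encoder.setProbes(2) from the question; every word in the note contributes to the vector, which is what the classifier needs in order to learn anything.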

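This happened to me earlier today. I see you have very few initial samples because you are just playing with the code, as I was. My problem was that, since this is an adaptive algorithm, I needed to set the interval and the averaging window it uses to "adapt" very low; otherwise it would never find a new best model: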

learningAlgorithm.setInterval(1); 
learningAlgorithm.setAveragingWindow(1); 

This way the algorithm can be forced to "adapt" after every vector it sees, which is crucial because your example code has only 6 vectors.
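The reasoning can be illustrated with a toy model in plain Java. This is not how AdaptiveLogisticRegression is implemented: the IntervalToy class is hypothetical, and the interval of 1000 is an arbitrary value chosen for illustration, not a documented Mahout default. It only shows why a learner that picks a best model once per interval returns null when the interval is larger than the number of training examples.

```java
// Hypothetical toy learner; it only mimics the "adapt every N examples" idea.
class IntervalToy {
    private final int interval;
    private int seen = 0;
    private String best = null; // stands in for the best model found so far

    IntervalToy(int interval) {
        this.interval = interval;
    }

    // Counts the example and picks a new "best" only at interval boundaries.
    void train(String example) {
        seen++;
        if (seen % interval == 0) {
            best = "model after " + seen + " examples";
        }
    }

    String getBest() {
        return best;
    }

    public static void main(String[] args) {
        for (int interval : new int[] {1000, 1}) {
            IntervalToy learner = new IntervalToy(interval);
            for (int i = 0; i < 6; i++) {
                learner.train("note " + i);
            }
            System.out.println("interval=" + interval + " -> best=" + learner.getBest());
        }
    }
}
```

With six examples, the toy learner with the large interval never reaches a boundary and reports null, while the one with an interval of 1 has a best model after the first example.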