文本分类(4)——朴素贝叶斯

朴素贝叶斯

def N_Bayes(testVector, method):  # `method` selects which weight entry (e.g. 'tfidf') is read
    """Classify every test document and dump the predictions per true category.

    For each document the score of a candidate category is
    ``log(docLen[candidate]) + sum(log(weight * p))`` taken over the document's
    words that appear in that candidate's weight table; the candidate with the
    highest score is the prediction.  Working in log space avoids the floating
    point underflow that a plain product of many small factors runs into.

    Args:
        testVector: mapping ``true category -> {file name -> {word -> p}}``
            where ``p`` is the word's weight inside that document.
        method: key used to pick the weight out of
            ``new_weigthDict[category][word]``; it is also the prefix of the
            output file name.

    Side effects:
        Prints a timestamp per true category and writes
        ``C:/lyr/DM/result/<method><category>.json`` holding
        ``{file name: predicted category}``.  A document for which no
        candidate reaches a positive score is stored with ``None``.

    NOTE(review): relies on the module-level ``new_weigthDict`` and ``docLen``;
    ``docLen`` is assumed to hold an entry for every category in
    ``new_weigthDict`` and is used as the (unnormalised) prior - confirm.
    NOTE(review): words missing from a candidate's table are skipped rather
    than smoothed, exactly as before; consider Laplace smoothing.
    """
    import math  # local import: the file's import block is not visible here

    for cate, corpus in testVector.items():
        result = {}
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()), '>', cate)
        for file_name, text in corpus.items():
            best_score = -math.inf
            category = None  # reset per document so a label never leaks from the previous file
            # Compare the document against every known category, not only its true one.
            for candidate, weights in new_weigthDict.items():
                prior = docLen[candidate]
                if prior <= 0:
                    continue  # a non-positive prior can never win
                score = math.log(prior)
                for word, p in text.items():  # p is the word's weight in this document
                    if word in weights:
                        factor = weights[word][method] * p
                        if factor <= 0:
                            # A zero factor zeroes the whole product.
                            score = -math.inf
                            break
                        score += math.log(factor)
                if score > best_score:
                    best_score = score
                    category = candidate
            result[file_name] = category
        fname = 'C:/lyr/DM/result/' + method + cate + '.json'
        with open(fname, 'w', encoding='utf-8') as fp:
            json.dump(result, fp)

# Classify the test vectors using the 'tfidf' weights.
N_Bayes(testVector, method='tfidf')

结果

这是我重新做的,跑结果和之前差了好几天,不知道哪里有问题,不太科学……(不太想看了 :( )

# Category names; each one has its own '<method><category>.json' prediction file.
cate=['Auto','business','edu','ent','healthy','mil','policy','sports','tech','tourism','women']
# Filled in by analysis(): true category -> Counter of the labels predicted for it.
allResults={}
def analysis(path, method='tfidf'):
    """Tally the predicted labels stored in each category's result file.

    For every name in ``cate`` the file ``path + method + name + '.json'`` is
    loaded (a JSON object mapping file name -> predicted label), its size is
    printed, and a ``Counter`` of its values is stored in
    ``allResults[name]``.  A timestamp is printed before and after the run.

    Args:
        path: directory prefix; it is concatenated as-is, so it must end
            with a path separator.
        method: file-name prefix of the result files.  Defaults to
            ``'tfidf'``.

    Calling the function again overwrites the previous counts, so the
    summary always reflects the files that were read last.
    """
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    for category in cate:
        fullPath = path + method + category + '.json'
        with open(fullPath, 'r', encoding='utf-8') as f:
            results = json.load(f)
        print(len(results))
        # Plain assignment (not setdefault) so a re-run replaces stale counts.
        allResults[category] = Counter(results.values())
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))

文本分类(4)——朴素贝叶斯

混淆矩阵

赠送混淆矩阵的画法

import matplotlib.pyplot as plt
# Class names shown on both axes of the confusion-matrix plot.
labels=['A', 'B', 'C', 'F', 'G', 'H', 'I', 'J', 'K', 'L']
# Positions half a step past each class index; used for the minor ticks.
tick_marks = np.arange(len(labels)) + 0.5
def plot_confusion_matrix(cm,title='Confusion Matrix',cmap=plt.cm.binary):
    """Render ``cm`` as an image on the current pyplot figure.

    Adds the title and a colorbar, labels both axes' ticks with the
    module-level ``labels`` and names the axes 'Predicted' (x) and
    'True' (y).

    Args:
        cm: the matrix handed to ``plt.imshow``.
        title: text placed above the plot.
        cmap: colormap forwarded to ``plt.imshow``.
    """
    positions = np.arange(len(labels))
    plt.imshow(cm, cmap=cmap, interpolation='nearest')
    plt.title(title)
    plt.colorbar()
    plt.xticks(positions, labels)
    plt.yticks(positions, labels)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    
# Confusion matrix of true vs. predicted labels, then each row divided by its sum.
cm = metrics.confusion_matrix(test_set.label, predicted)
np.set_printoptions(precision=2)
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = cm.astype('float') / row_totals

plt.figure(figsize=(8, 8), dpi=100)
ind_array = np.arange(len(labels))
# Write the value into every cell above 0.01, walking the matrix row by row.
for y_val in ind_array:
    for x_val in ind_array:
        c = cm_normalized[y_val][x_val]
        if c > 0.01:
            plt.text(x_val, y_val, "%0.2f" % (c,),
                     color='red', fontsize=7, va='center', ha='center')

# Put minor ticks at tick_marks and draw the grid on them; hide the tick marks.
axes = plt.gca()
axes.set_xticks(tick_marks, minor=True)
axes.set_yticks(tick_marks, minor=True)
axes.xaxis.set_ticks_position('none')
axes.yaxis.set_ticks_position('none')
plt.grid(True, which='minor', linestyle='-')
plt.gcf().subplots_adjust(bottom=0.15)

plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')
plt.show()