文本分类(4)——朴素贝叶斯
朴素贝叶斯
def N_Bayes(testVector, method, weights=None, doc_len=None,
            out_dir='C:/lyr/DM/result/'):
    """Label every test document with its highest-scoring category and save the labels.

    For each document and each candidate category ``c`` the score is the
    logarithm of

        doc_len[c] * prod(weights[c][word][method] * p)

    taken over the document's words that appear in ``weights[c]``.  Working
    in log space gives the same ranking as the plain product but avoids
    floating-point underflow for long documents.  A factor <= 0 makes the
    product zero, so that category scores ``-inf`` and cannot win.

    Args:
        testVector: ``{true_category: {file_name: {word: p}}}``; ``p`` is the
            word's weight in that document.
        method: which weight to read from ``weights[c][word]`` (i.e. which
            weighting dictionary to use, e.g. ``'tfidf'``); also the prefix of
            the output file names.
        weights: ``{category: {word: {method: weight}}}``.  Defaults to the
            module-level ``new_weigthDict``.
        doc_len: ``{category: number}`` multiplied in as a per-category prior
            (presumably the training-document count -- confirm).  Defaults to
            the module-level ``docLen``.
        out_dir: directory that receives one ``<method><true_category>.json``
            file per true category, each a ``{file_name: predicted}`` mapping.

    Returns:
        ``{true_category: {file_name: predicted_category}}``.  A document for
        which every category scores ``-inf`` is labelled ``None``.
    """
    import json
    import math
    import os
    import time

    # Fall back to the module-level tables the function has always used.
    if weights is None:
        weights = new_weigthDict
    if doc_len is None:
        doc_len = docLen

    def log_or_ninf(value):
        # A non-positive factor zeroes the linear-space product; mirror that as -inf.
        return math.log(value) if value > 0 else -math.inf

    all_results = {}
    for cate, corpus in testVector.items():
        result = {}
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()), '>', cate)
        for file_name, text in corpus.items():
            # Score every candidate category, not just the document's true one;
            # otherwise the "prediction" would simply echo the true label.
            best_category, best_score = None, -math.inf
            for candidate, cand_weights in weights.items():
                score = log_or_ninf(doc_len[candidate])
                for word, p in text.items():
                    if score == -math.inf:
                        break  # product is already zero; further words cannot change it
                    if word in cand_weights:
                        score += log_or_ninf(cand_weights[word][method] * p)
                if score > best_score:
                    best_category, best_score = candidate, score
            result[file_name] = best_category
        fname = os.path.join(out_dir, method + cate + '.json')
        with open(fname, 'w', encoding='utf-8') as fp:
            json.dump(result, fp)
        all_results[cate] = result
    return all_results
# Classify the test set using the TF-IDF weights.
N_Bayes(testVector, method='tfidf')
结果
这是我重新做的，跑结果和之前差了好几天，不知道哪里有问题，不太科学……（不太想看了 :( ）
# Category names; each one is also part of a result file name read by `analysis`.
cate = ['Auto', 'business', 'edu', 'ent', 'healthy', 'mil', 'policy',
        'sports', 'tech', 'tourism', 'women']
# category -> Counter of predicted labels, filled by `analysis`.
allResults = {}


def analysis(path, method='tfidf'):
    """Tally the predicted labels stored in each per-category result file.

    For every category ``c`` in ``cate`` this reads the JSON file
    ``path + method + c + '.json'`` (a ``{file_name: predicted_label}``
    mapping) and stores a ``Counter`` of its predicted labels in
    ``allResults[c]``.

    Args:
        path: prefix of the result files; it is concatenated directly, so a
            directory must end with a separator.
        method: weighting-method prefix of the file names (default
            ``'tfidf'``, the value that used to be hard-coded).

    Returns:
        The module-level ``allResults`` dict, which is also updated in place.
    """
    import json
    import time
    from collections import Counter

    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    for category in cate:
        full_path = path + method + category + '.json'
        with open(full_path, 'r', encoding='utf-8') as f:
            results = json.load(f)
        print(len(results))
        # Assign rather than setdefault so re-running refreshes stale counts.
        allResults[category] = Counter(results.values())
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    return allResults
混淆矩阵
附：混淆矩阵的画法（以下代码中的 labels、test_set、predicted 需替换为自己的数据）
import matplotlib.pyplot as plt
# Axis tick labels for the confusion-matrix plot, one per class.
# NOTE(review): 10 labels here vs. 11 categories in `cate`; this list looks like
# it comes from a different dataset -- replace it so it matches your classes.
labels = ['A', 'B', 'C', 'F', 'G', 'H', 'I', 'J', 'K', 'L']
# Positions half-way between cell centres; used as minor ticks so grid lines
# are drawn on the cell borders.
tick_marks = np.arange(len(labels)) + 0.5
def plot_confusion_matrix(cm, title='Confusion Matrix', cmap=plt.cm.binary):
    """Draw matrix *cm* on the current pyplot figure with a colour bar and class ticks.

    Rows are labelled "True" (y axis) and columns "Predicted" (x axis); tick
    labels come from the module-level ``labels`` list.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    positions = np.arange(len(labels))
    for place_ticks in (plt.xticks, plt.yticks):
        place_ticks(positions, labels)
    plt.ylabel('True')
    plt.xlabel('Predicted')
# Confusion matrix of true vs. predicted labels, row-normalised so each row
# (one true class) holds the fraction predicted as each class.
cm = metrics.confusion_matrix(test_set.label, predicted)
np.set_printoptions(precision=2)
row_totals = cm.sum(axis=1, keepdims=True)
# A true class with no samples has a zero row sum; plain division would fill
# that row with NaN (and emit a RuntimeWarning).  Leave such rows at 0 instead.
cm_normalized = np.divide(cm.astype('float'), row_totals,
                          out=np.zeros(cm.shape, dtype=float),
                          where=row_totals != 0)
plt.figure(figsize=(8, 8), dpi=100)
ind_array = np.arange(len(labels))
x, y = np.meshgrid(ind_array, ind_array)
# Print every fraction above 1% inside its cell (x = predicted, y = true).
for x_val, y_val in zip(x.flatten(), y.flatten()):
    c = cm_normalized[y_val][x_val]
    if c > 0.01:
        plt.text(x_val, y_val, "%0.2f" % (c,), color='red', fontsize=7,
                 va='center', ha='center')
# Minor ticks on the cell borders draw the grid between cells; hide tick marks.
plt.gca().set_xticks(tick_marks, minor=True)
plt.gca().set_yticks(tick_marks, minor=True)
plt.gca().xaxis.set_ticks_position('none')
plt.gca().yaxis.set_ticks_position('none')
plt.grid(True, which='minor', linestyle='-')
plt.gcf().subplots_adjust(bottom=0.15)
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')
plt.show()