RandomForestClassifier并不适用于所有类别
问题描述:
clf = RandomForestClassifier(min_samples_leaf=20)
clf.fit(X_train, y)
prob_pos= clf.predict_proba(X_test)
尺寸预测概率:RandomForestClassifier并不适用于所有类别
(Pdb) print X_train.shape,X_test.shape,y.shape
(1422392L, 14L) (233081L, 14L) (1422392L, 6L)
输出:
(Pdb) prob_pos
[array([[ 0.96133658, 0.03866342],
[ 0.93514554, 0.06485446],
[ 0.91520408, 0.08479592],
...,
[ 0.95826389, 0.04173611],
[ 0.97130832, 0.02869168],
[ 0.93223876, 0.06776124]]), array([[ 0.9907225 , 0.0092775 ],
[ 0.94489664, 0.05510336],
[ 0.98428571, 0.01571429],
...,
[ 0.96415476, 0.03584524],
[ 0.99193939, 0.00806061],
[ 0.98918919, 0.01081081]]), array([[ 0.9907225 , 0.0092775 ],
[ 0.98253968, 0.01746032],
[ 0.98166667, 0.01833333],
...,
[ 0.96415476, 0.03584524],
[ 0.99444444, 0.00555556],
[ 0.99004914, 0.00995086]]), array([[ 1. , 0. ],
[ 0.99642857, 0.00357143],
[ 0.98082011, 0.01917989],
...,
[ 0.96978897, 0.03021103],
[ 0.97467974, 0.02532026],
[ 1. , 0. ]]), array([[ 1. , 0. ],
[ 1. , 0. ],
[ 0.98238095, 0.01761905],
...,
[ 1. , 0. ],
[ 0.99661017, 0.00338983],
[ 0.99428571, 0.00571429]]), array([[ 1. , 0. ],
[ 1. , 0. ],
[ 0.99285714, 0.00714286],
...,
[ 0.99705882, 0.00294118],
[ 0.97885167, 0.02114833],
[ 0.98688312, 0.01311688]])]
我不明白,为什么概率看起来不是X-train_samples×6?
答
由于y.shape
是(1422392L,6L),因此您有6个不同的输出。因此,您有6个数组的列表作为概率输出。由于每个数组都有2列,因此我得出结论:每个输出都有2个类。确实有两类吗?然后,一切看起来都很好。
如果6个类是一个热编码类似[1,0,0,0,0,0]
,这对于6个输出是有效的2类。然后,列表中的第一个数组为第一个输出的“0”和“1”概率,第二个数组为第二个输出的“0”和“1”概率,依此类推。
您正在根据here in scikit-learn documentaion所述实际解决多输出问题,请参见“1.10.3。多输出问题”。
获得6个类的概率的最简单方法是将您的类编码为1,2,3,4,5,6并获得1列的y
。然后你会得到一个6列的数组作为概率
如果你有两个类有时,如[1,0,1,0,0,1]
,那么你的问题本质上是多输出(在我的评论中说'多级'这是一个错位)。要获得6个类的概率,您需要收集列表中每个数组的第二列。该代码是
prob_nx6 = np.array([arr[:,1] for arr in prob_pos]).T
,现在我是编辑这个答案,我想出了一个简单的代码
prob_nx6 = np.hstack(prob_pos)[:,1::2]
这会给你形状的二维数组(N,6)(N = 1422392在你的情况)。如果希望每个长度为6的n个阵列的列表,简码是
prob_nx6_liofarr = list(np.hstack(prob_pos)[:,1::2])
如果此列表内的每个元素必须是列表,而不是阵列(即列表的列表)时,代码是
prob_nx6_liofli = np.hstack(prob_pos)[:,1::2].tolist()
不,我只有6个班。 –
那么你的6个输出是什么?你的课是一个热点编码吗? – lanenok
6个输出是类。是的,二进制'[1,0,0,0,0,0]' –