MultiLabelSoftMarginLoss
1、MultiLabelSoftMarginLoss原理
MultiLabelSoftMarginLoss针对multi-label one-versus-all(多分类,且每个样本只能属于一个类)的情形。
loss的计算公式如下:
其中,x是模型预测的标签,x的shape是(N,C),N表示batch size,C是分类数;y是真实标签,shape也是(N,C),。
的值域是(0,
);
的值域是(1,
);
的值域是(0,1);
的值域是(-
,0),函数曲线如图1所示:
为了看得更清楚一点,再画一下[-10,10]范围内的曲线,如图2:
当y[i]=1得时候,x[i]越大==》越大==》loss越小(因为
前面有个负号);
的函数曲线如图3所示:
[-10,10]范围内的曲线,如图4所示:
当y[i]=0得时候,x[i]越小==》越大==》loss越小(因为
前面有个负号);
2、使用MultiLabelSoftMarginLoss进行图片多分类
2.1 数据源以及如何打标签
以mnist数据源为例,共有10个分类。mnist中每张图片的标签是0-9中的一个数字,我们需要对标签进行转换:对于标签是0的图片,转换成的新标签是[1,0,0,0,0,0,0,0,0,0];对于标签是1的图片,转换成的新标签是[0,1,0,0,0,0,0,0,0,0];依此类推。
2.2 模型训练
2.2.1 模型搭建
MultiLabelSoftMarginLoss层前一层输出的特征图的大小必须是1*10,我们可以使用一个out_features是10的Linear层(全连接层)来实现。