多标签“Softmax+交叉熵”的软标签版本
By 苏剑林 | 2022-05-07 | 57941位读者 |(注:本文的相关内容已整理成论文《ZLPR: A Novel Loss for Multi-label Classification》,如需引用可以直接引用英文论文,谢谢。)
在《将“Softmax+交叉熵”推广到多标签分类问题》中,我们提出了一个用于多标签分类的损失函数:
log(1+∑i∈Ωnegesi)+log(1+∑j∈Ωpose−sj)
这个损失函数有着单标签分类中“Softmax+交叉熵”的优点,即便在正负类不平衡的依然能够有效工作。但从这个损失函数的形式我们可以看到,它只适用于“硬标签”,这就意味着label smoothing、mixup等技巧就没法用了。本文则尝试解决这个问题,提出上述损失函数的一个软标签版本。
巧妙联系 #
多标签分类的经典方案就是转化为多个二分类问题,即每个类别用sigmoid函数σ(x)=1/(1+e−x)激活,然后各自用二分类交叉熵损失。当正负类别极其不平衡时,这种做法的表现通常会比较糟糕,而相比之下损失(1)通常是一个更优的选择。
在之前文章的评论区中,读者 @wu.yan 揭示了多个“sigmoid+二分类交叉熵”与式(1)的一个巧妙的联系:多个“sigmoid+二分类交叉熵”可以适当地改写成
−∑j∈Ωposlogσ(sj)−∑i∈Ωneglog(1−σ(si))=log∏j∈Ωpos(1+e−sj)+log∏i∈Ωneg(1+esi)=log(1+∑j∈Ωpose−sj+⋯)+log(1+∑i∈Ωnegesi+⋯)
对比式(1),我们可以发现式(1)正好是上述多个“sigmoid+二分类交叉熵”的损失去掉了⋯所表示的高阶项!在正负类别不平衡时,这些高阶项占据了过高的权重,加剧了不平衡问题,从而效果不佳;相反,去掉这些高阶项后,并没有改变损失函数的作用(希望正类得分大于0、负类得分小于0),同时因为括号内的求和数跟类别数是线性关系,因此正负类各自的损失差距不会太大。
形式猜测 #
这个巧妙联系告诉我们,要寻找式(1)的软标签版本,可以尝试从多个“sigmoid+二分类交叉熵”的软标签版本出发,然后尝试去掉高阶项。所谓软标签,指的是标签不再是0或1,而是0~1之间的任意实数都有可能,表示属于该类的可能性。而对于二分类交叉熵,它的软标签版本很简单:
−tlogσ(s)−(1−t)log(1−σ(s))
这里t就是软标签,而s就是对应的打分。模仿过程(2),我们可以得到
−∑itilogσ(si)−∑i(1−ti)log(1−σ(si))=log∏i(1+e−si)ti+log∏i(1+esi)1−ti=log∏i(1+tie−si+⋯)+log∏i(1+(1−ti)esi+⋯)=log(1+∑itie−si+⋯)+log(1+∑i(1−ti)esi+⋯)
如果去掉高阶项,那么就得到
log(1+∑itie−si)+log(1+∑i(1−ti)esi)
它就是式(1)的软标签版本的候选形式,可以发现当ti∈{0,1}时,正好是退化为式(1)的。
证明结果 #
就目前来说,式(5)顶多是一个“候选”形式,要将它“转正”,我们需要证明在ti为0~1浮点数时,式(5)能学出有意义的结果。所谓有意义,指的是理论上能够通过si来重构ti的信息(si是模型预测结果,ti是给定标签,所以si能重构ti是机器学习的目标)。
为此,我们记式(5)为l,并求si的偏导数:
∂l∂si=−tie−si1+∑itie−si+(1−ti)esi1+∑i(1−ti)esi
我们知道l的最小值出现在所有∂l∂si都等于0时,直接去解方程组∂l∂si=0并不容易,但笔者留意到一个神奇的“巧合”:当tie−si=(1−ti)esi时,每个∂l∂si自动地等于0!所以tie−si=(1−ti)esi应该就是l的最优解了,解得
ti=11+e−2si=σ(2si)
这是一个很漂亮的结果,它告诉我们几个信息:
1、式(5)确实是式(1)合理的软标签推广,它能通过si完全重建ti的信息,其形式也刚好与sigmoid相关;
2、如果我们要将结果输出为0~1的概率值,那么正确的做法应该是σ(2si)而不是直觉中的σ(si);
3、既然最后的概率公式也具有sigmoid的形式,那么反过来想,也可以理解为我们依旧还是在学习多个sigmoid激活的二分类问题,只不过损失函数换成了式(5)。
实现技巧 #
式(5)的实现可以参考bert4keras的代码multilabel_categorical_crossentropy,其中有个小细节值得跟大家一起交流一下。
首先,我们将式(5)可以等价地改写成
log(1+∑ie−si+logti)+log(1+∑iesi+log(1−ti))
所以看上去,只需要将logti加到−si、将log(1−ti)加到si上,补零后做常规的logsumexp即可。但实际上,ti是有可能取到0或1的,对应的logti或log(1−ti)就是负无穷,而框架无法直接处理负无穷,因此通常在log之前需要clip一下,即选定ϵ>0后定义
clip(t)={ϵ,t<ϵt,ϵ≤t≤1−ϵ1−ϵ,t>1−ϵ
但这样一clip,问题就来了。由于ϵ不是真的无穷小,比如ϵ=10−7,那么logϵ大约是−16左右;而像GlobalPointer这样的场景中,我们会提前把不合理的si给mask掉,方式是将对应的si置为一个绝对值很大的负数,比如−107;然而我们再看式(8),第一项的求和对象是e−si+logti,所以−107就会变成107,如果ti没有clip,那么理论上logti是log0=−∞,可以把−si+logti重新变回负无穷,但刚才我们已经看到进行了clip之后的logti顶多就是−16,远远比不上−si的107,所以−si+logti依然是一个大正数。
为了解决这个问题,我们不止要对ti进行clip,我们还要找出原本小于ϵ的ti,手动将对应的−si置为绝对值很大的负数,同样还要找出大于1−ϵ的ti,将对应的si置为绝对值很大的负数,这样做就是将小于ϵ的按照绝对等于0额外处理,将大于1−ϵ的按照绝对等于1处理。
文章小结 #
本文主要将笔者之前提出的多标签“Softmax+交叉熵”推广到软标签场景,有了对应的软标签版本后,我们就可以将它与label smoothing、mixup等技巧结合起来了,像GlobalPointer等又可以多一个炼丹方向。
转载到请包括本文地址:https://spaces.ac.cn/archives/9064
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (May. 07, 2022). 《多标签“Softmax+交叉熵”的软标签版本 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9064
@online{kexuefm-9064,
title={多标签“Softmax+交叉熵”的软标签版本},
author={苏剑林},
year={2022},
month={May},
url={\url{https://spaces.ac.cn/archives/9064}},
}
May 15th, 2022
感谢分享,请问苏神这个时候 网络预测阈值该如何选择?因为真实标签是[0,1]区间的值这个时候正负例如何构建?
既然是软标签,就没有正负样本的说法。阈值是人为自己选的,反正最后输出的概率是σ(2σi)。
May 17th, 2022
感谢苏神分享,第3个公示的第一项前面是否应该有个负号呢
对的,感谢指出,已经修正。
May 19th, 2022
我推导了一下,(7)确实是(6)的唯一解。话说(6)中分母求和下标也用i稍微有点不严谨,不过问题不大。
我觉得从运算的先后顺利来看,分母求和下标也用i并不会造成理解上的混乱。换用一个下标只不过是主流的一种强迫症而已。
May 31st, 2022
感谢苏神分享,小白想问下是不是如果数据集只有硬标签,这个软标签版本就用不上呢?
硬标签也可能用到mixup、标签平滑等技巧,这时候就转化为软标签了。
June 14th, 2022
请问式5优化的目的是不是使sigmoid(2*y_pred)=y_true?
但是我随机生成20个y_pred(实数范围),20个y_true(0到1范围)实验了下,发现并不是满足上式时,式5值最小。
补充一下:
我发现用单对y_pred和y_true试验时,就满足越接近式7时式5就越小,但用一组数据去试验就不满足这个关系。
理论解是这样,不清楚你实验的具体设置。
July 13th, 2022
请教下onehot转化为软标签后,loss下降到一定程度就不变了,可能是什么原因呢
可能是显然成立的原因(谁说loss的最小值一定是0?)
July 13th, 2022
K.infinity() 苏神,这个值在keras 2.3.1中没有啊,请问你是在哪个版本上跑的,话说这个值有pytorch替代版本吗?
1、安装并导入最新版本bert4keras;
2、在我这里不会有任何模型的pytorch实现,pytorch请自便。
July 21st, 2022
苏神,多标签之间的数据量相差很大会有问题吗,比如三个类别分别为35万,7万,3千
May 2nd, 2024
苏神,想问一下为什么二分类的交叉熵的软标签损失版本可以仍然用式(3)来表示,感觉明显不合理呀,有更好的表示方式吗
为什么不合理?它本质上就是两个分布的KL散度。