CAN:借助先验分布提升分类性能的简单后处理技巧
By 苏剑林 | 2021-10-22 | 165287位读者 |顾名思义,本文将会介绍一种用于分类问题的后处理技巧——CAN(Classification with Alternating Normalization),出自论文《When in Doubt: Improving Classification Performance with Alternating Normalization》。经过笔者的实测,CAN确实多数情况下能提升多分类问题的效果,而且几乎没有增加预测成本,因为它仅仅是对预测结果的简单重新归一化操作。
有趣的是,其实CAN的思想是非常朴素的,朴素到每个人在生活中都应该用过同样的思想。然而,CAN的论文却没有很好地说清楚这个思想,只是纯粹形式化地介绍和实验这个方法。本文的分享中,将会尽量将算法思想介绍清楚。
思想例子 #
假设有一个二分类问题,模型对于输入a给出的预测结果是p(a)=[0.05,0.95],那么我们就可以给出预测类别为1;接下来,对于输入b,模型给出的预测结果是p(b)=[0.5,0.5],这时候处于最不确定的状态,我们也不知道输出哪个类别好。
但是,假如我告诉你:1、类别必然是0或1其中之一;2、两个类别的出现概率各为0.5。在这两点先验信息之下,由于前一个样本预测结果为1,那么基于朴素的均匀思想,我们是否更倾向于将后一个样本预测为0,以得到一个满足第二点先验的预测结果?
这样的例子还有很多,比如做10道选择题,前9道你都比较有信心,第10题完全不会只能瞎蒙,然后你一看发现前9题选A、B、C的都有就是没有一个选D的,那么第10题在蒙的时候你会不会更倾向于选D?
这些简单例子的背后,有着跟CAN同样的思想,它其实就是用先验分布来校正低置信度的预测结果,使得新的预测结果的分布更接近先验分布。
不确定性 #
准确来说,CAN是针对低置信度预测结果的后处理手段,所以我们首先要有一个衡量预测结果不确定性的指标。常见的度量是“熵”(参考《“熵”不起:从熵、最大熵原理到最大熵模型(一)》),对于p=[p1,p2,⋯,pm],定义为:
H(p)=−m∑i=1pilogpi
然而,虽然熵是一个常见选择,但其实它得出的结果并不总是符合我们的直观理解。比如对于p(a)=[0.5,0.25,0.25]和p(b)=[0.5,0.5,0],直接套用公式得到H(p(a))>H(p(b)),但就我们的分类场景而言,显然我们会认为p(b)比p(a)更不确定,所以直接用熵还不够合理。
一个简单的修正是只用前top-k个概率值来算熵,不失一般性,假设p1,p2,⋯,pk是概率最高的k个值,那么
Htop-k(p)=−k∑i=1˜pilog˜pi
其中˜pi=pi/k∑i=1pi。为了得到一个0~1范围内的结果,我们取Htop-k(p)/logk为最终的不确定性指标。
算法步骤 #
现在假设我们有N个样本需要预测类别,模型直接的预测结果是N个概率分布p(1),p(2),⋯,p(N),假设测试样本和训练样本是同分布的,那么完美的预测结果应该有:
1NN∑i=1p(i)=˜p
其中˜p是类别的先验分布,我们可以直接从训练集估计。也就是说,全体预测结果应该跟先验分布是一致的,但受限于模型性能等原因,实际的预测结果可能明显偏离上式,这时候我们就可以人为修正这部分。
具体来说,我们选定一个阈值τ,将指标小于τ的预测结果视为高置信度的,而大于等于τ的则是低置信度的,不失一般性,我们假设前n个结果p(1),p(2),⋯,p(n)属于高置信度的,而剩下的N−n个属于低置信度的。我们认为高置信度部分是更加可靠的,所以它们不用修正,并且可以用它们来作为“标准参考系”来修正低置信度部分。
具体来说,对于∀j∈{n+1,n+2,⋯,N},我们将p(j)与高置信度的p(1),p(2),⋯,p(n)一起,执行一次“行间”标准化:
p(k)←p(k)/ˉpטp,ˉp=1n+1(p(j)+n∑i=1p(i))
这里的k∈{1,2,⋯,n}∪{j},其中乘除法都是element-wise的。不难发现,这个标准化的目的是使得所有新的p(k)的平均向量等于先验分布˜p,也就是促使式(3)的成立。然而,这样标准化之后,每个p(k)就未必满足归一化了,所以我们还要执行一次“行内”标准化:
p(k)←p(k)im∑i=1p(k)i
但这样一来,式(3)可能又不成立了。所以理论上我们可以交替迭代执行这两步,直到结果收敛(不过实验结果显示一般情况下一次的效果是最好的)。最后,我们只保留最新的p(j)作为原来第j个样本的预测结果,其余的p(k)均弃之不用。
注意,这个过程需要我们遍历每个低置信度结果j∈{n+1,n+2,⋯,N}执行,也就是说是逐个样本进行修正,而不是一次性修正的,每个p(j)都借助原始的高置信度结果p(1),p(2),⋯,p(n)组合来按照上述步骤迭代,虽然迭代过程中对应的p(1),p(2),⋯,p(n)都会随之更新,但那只是临时结果,最后都是弃之不用的,每次修正都是用原始的p(1),p(2),⋯,p(n)。
参考实现 #
这是笔者给出的参考实现代码:
# 预测结果,计算修正前准确率
y_pred = model.predict(
valid_generator.fortest(), steps=len(valid_generator), verbose=True
)
y_true = np.array([d[1] for d in valid_data])
acc_original = np.mean([y_pred.argmax(1) == y_true])
print('original acc: %s' % acc_original)
# 评价每个预测结果的不确定性
k = 3
y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
y_pred_uncertainty = -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k)
# 选择阈值,划分高、低置信度两部分
threshold = 0.9
y_pred_confident = y_pred[y_pred_uncertainty < threshold]
y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold]
y_true_confident = y_true[y_pred_uncertainty < threshold]
y_true_unconfident = y_true[y_pred_uncertainty >= threshold]
# 显示两部分各自的准确率
# 一般而言,高置信度集准确率会远高于低置信度的
acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean()
acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean()
print('confident acc: %s' % acc_confident)
print('unconfident acc: %s' % acc_unconfident)
# 从训练集统计先验分布
prior = np.zeros(num_classes)
for d in train_data:
prior[d[1]] += 1.
prior /= prior.sum()
# 逐个修改低置信度样本,并重新评价准确率
right, alpha, iters = 0, 1, 1
for i, y in enumerate(y_pred_unconfident):
Y = np.concatenate([y_pred_confident, y[None]], axis=0)
for j in range(iters):
Y = Y**alpha
Y /= Y.mean(axis=0, keepdims=True)
Y *= prior[None]
Y /= Y.sum(axis=1, keepdims=True)
y = Y[-1]
if y.argmax() == y_true_unconfident[i]:
right += 1
# 输出修正后的准确率
acc_final = (acc_confident * len(y_pred_confident) + right) / len(y_pred)
print('new unconfident acc: %s' % (right / (i + 1.)))
print('final acc: %s' % acc_final)
实验结果 #
那么,这样的简单后处理,究竟能带来多大的提升呢?原论文给出的实验结果是相当可观的:
笔者也在CLUE上的两个中文文本分类任务上做了实验,显示基本也有点提升,但没那么可观(验证集结果):
IFLYTEK(类别数:119)TNEWS(类别数:15)BERT60.06%56.80%BERT + CAN60.52%56.86%RoBERTa60.64%58.06%RoBERTa + CAN60.95%58.00%
大体上来说,类别数目越多,效果提升越明显,如果类别数目比较少,那么可能提升比较微弱甚至会下降(当然就算下降也是微弱的),所以这算是一个“几乎免费的午餐”了。超参数选择方面,上面给出的中文结果,只迭代了1次,k的选择为3、τ的选择为0.9,经过简单的调试,发现这基本上已经是比较优的参数组合了。
还有的读者可能想问前面说的“高置信度那部分结果更可靠”这个情况是否真的成立?至少在笔者的两个中文实验上它是明显成立的,比如IFLYTEK任务,筛选出来的高置信度集准确率为0.63+,而低置信度集的准确率只有0.22+;TNEWS任务类似,高置信度集准确率为0.58+,而低置信度集的准确率只有0.23+。
个人评价 #
最后再来综合地思考和评价一下CAN。
首先,一个很自然的疑问是为什么不直接将所有低置信度结果跟高置信度结果拼在一起进行修正,而是要逐个进行修正?笔者不知道原论文作者有没有对比过,但笔者确实实验过这个想法,结果是批量修正有时跟逐个修正持平,但有时也会下降。其实也可以理解,CAN本意应该是借助先验分布,结合高置信度结果来修正低置信度的,在这个过程中,如果掺入越多的低置信度结果,那么最终的偏差可能就越大,因此理论上逐个修正会比批量修正更为可靠。
说到原论文,读过CAN论文的读者,应该能发现本文介绍与CAN原论文大致有三点不同:
1、不确定性指标的计算方法不同。按照原论文的描述,它最终的不确定性指标计算方式应该是
−1logmk∑i=1pilogpi
也就是说,它也是top-k个概率算熵的形式,但是它没有对这k个概率值重新归一化,并且它将其压缩到0~1之间的因子是logm而不是logk(因为它没有重新归一化,所以只有除logm才能保证0~1之间)。经过笔者测试,原论文的这种方式计算出来的结果通常明显小于1,这不利于我们对阈值的感知和调试。
2、对CAN的介绍方式不同。原论文是纯粹数学化、矩阵化地陈述CAN的算法步骤,而且没有介绍算法的思想来源,这对理解CAN是相当不友好的。如果读者没有自行深入思考算法原理,是很难理解为什么这样的后处理手段就能提升分类效果的,而在彻底弄懂之后则会有一种故弄玄虚之感。
3、CAN的算法流程略有不同。原论文在迭代过程中还引入了参数α,使得式(4)变为
p(k)←[p(k)]α/ˉpטp,ˉp=1n+1([p(j)]α+n∑i=1[p(i)]α)
也就是对每个结果进行α次方后再迭代。当然,原论文也没有对此进行解释,而在笔者看来,该参数纯粹是为了调参而引入的(参数多了,总能把效果调到有所提升),没有太多实际意义。而且笔者自己在实验中发现,α=1基本已经是最优选择了,精调α也很难获得是实质收益。
文章小结 #
本文介绍了一种名为CAN的简单后处理技巧,它借助先验分布来将预测结果重新归一化,几乎没有增加多少计算成本就能提高分类性能。经过笔者的实验,CAN确实能给分类效果带来一定提升,并且通常来说类别数越多,效果越明显。
转载到请包括本文地址:https://spaces.ac.cn/archives/8728
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Oct. 22, 2021). 《CAN:借助先验分布提升分类性能的简单后处理技巧 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/8728
@online{kexuefm-8728,
title={CAN:借助先验分布提升分类性能的简单后处理技巧},
author={苏剑林},
year={2021},
month={Oct},
url={\url{https://spaces.ac.cn/archives/8728}},
}
October 22nd, 2021
刷榜应该很有用,看起来结果像是会退化成低置信度的结果按照高置信度的结果进行一波瞎蒙。线上任务应该还是要卡个阈值丢掉这些低置信的结果,再设定一个default的结果比较好。
线上的话直接可以回答“不知道”、“我现在还无法回答这个问题”之类的。
October 22nd, 2021
很欣赏这类工作。 原理简单,方法有效。 实验和观点一致。 相比之下,我的工作魔改网络,强行解释,不知道高到哪里去了。 这样的工作看似没有什么很惊艳的地方,但实际上把一个简单的方法应用上去还需要很多工程上的tricks。不论是工业应用还是比赛刷分,这个方法都能发挥其作用。真的不错!
整体来看,你是在赞美这个方法的,但老兄的文字表达还需要加强呀(捂脸~)。比如“相比之下,我的工作魔改网络,强行解释,不知道高到哪里去了”,这句话的字面意思,是表达“我的工作很好”~
苏学长对论文的解读比原论文还要强呀
October 25th, 2021
"用先验分布来校正低置信度的预测结果,使得新的预测结果的分布更接近先验分布"总结的很精辟。
October 26th, 2021
看文章中的公式(4),这个迭代公式的作用就是将每个类别的概率根据先验分布和当前计算的平均概率分布值的比值,对相应的类别概率进行拉升或压缩。能不能直接通过高置信度的样本集,把每个类别对应的概率的拉升压缩比直接算出来。这样对每个需要检测的低置信测试样本,直接将概率位乘上相应的拉升压缩比,然后归一化。这样后处理的话,感觉比较简单。也不知道效果如何。
我理解你的意思,相当于算ˉp的时候只用高置信度样本的,不用新增的那个低置信度样本吧?在迭代一次的情况下,如果高置信度样本比较多,那么事实上几乎没差别。
October 27th, 2021
我只想说苏神举的做题的例子真的太恰当了!
我只想说苏神举的做题的例子真的太恰当了!想想看,为啥要取topk,就像有四个选项的选择题,总有一两个选项能够被确定pass的,可能最后就是卡在了选 A 还是选 B 的情况下了。
October 27th, 2021
这是赌徒谬误的一个反映吧,因为一般假设样本之间是独立的,这个把它弄得相关了起来。对达到一定批量下的样本预测结果可能有效。
这不是赌徒谬误,而是合理的概率现象。
我们可以先看三个题目:
1、如果一对夫妇有一男一女两个孩子,其中一个是男孩,那么另一个是女孩的概率是多少?
2、如果一对夫妇有两个孩子,其中一个是男孩,那么另一个是女孩的概率是多少?
3、如果一对夫妇有两个孩子,其中老大是男孩,那么老二是女孩的概率是多少?
第1题显然是100%;第2题是比较经典的条件概率题,答案是2/3;第3题自然就是50%。
所谓赌徒谬误,指的是在第3题的条件下,相信答案大于50%,就是所谓的“第一胎生了女孩,第二胎总该是男孩了吧”的思想。那么第2问为什么会得出大于50%的结果呢?它跟第3题的区别是我们虽然知道一个是男孩,但我们不知道是第几个,因此是乱序的,包含“男男、男女、女男”三种可能性,另一个为女的可能性更大,它没有先后顺序,而是在事情已经发生完毕的情况下,猜测结果的可能性。
那么,CAN对应哪一种情况呢?表面上看是第1种或第3种,事实上是第2种。比如做10道选择题,虽然题目都有编号,虽然题目通常按照从易到难排序,但本质上这些题目是无序的,或者说题目的顺序与答案的选择没有相关关系(不会出现更难的题更倾向于选D的情况)。这样一来,它就是上述三个例子中的第2种情形。那么,假如前9题都没有一道选D,那么基于正常的答案分布均匀假设,第10题选D的概率确实是大于选A、B、C的(跟第2题的答案2/3大于50%同理)。当然,不至于大到100%,但超过25%是必然的。
0、第二个问题确实保证女孩的概率大于50%,但这不是CAN的原理
1、当利用置信度进行筛选的时候,其实就已经确定了作为基准的高置信度样本的索引了,也就是知道了“已知的那个男孩到底是老大还是老二”了,所以这依旧是类似于第三个题目而不是第二个题目:“知道高置信度的样本是哪些”近似于“知道男孩是哪个”,而不那么近似于“知道男孩至少有多少个”
2、“赌徒谬误”之所以是“谬误”而不是“完全错误”的原因,是因为现实赌博达不到使它正确的一个条件:即大样本,但是在所谓“大数据时代”,达到这个要求不是小意思吗?所以CAN这个方法还是能用的。
3、CAN是在做跟“赌徒谬误”类似的事情,但是它满足了估计对象是“大样本”的条件:一堆高置信度样本+一个低置信度样本都是它的估计对象,只不过修正的只有一个低置信度样本而已,而单纯的“赌徒谬误”其实只估计一个或极少量样本
4、其实我开这楼的表述也不准确,与其说CAN是“赌徒谬误”的反映,不如说它是“大数定律”的反映,而这也就是CAN的原理,仅此而已
总结:CAN的原理就是简单的大数定律,既不是什么“赌徒谬误”,也不是所谓“无序”的条件概率
“当利用置信度进行筛选的时候,其实就已经确定了作为基准的高置信度样本的索引了”这句话,最重要的区别在于我究竟有没有去看低置信度样本。比如两个孩子,我随机挑一个看,发现他是男的,但另一个我并没有去看,所以我只能得出至少有一个是男的;如果我去看了另一个孩子,但分辨不出是男是女,这又不一样。
所以,这个应该是我们对模型的低置信度样本的理解不一样。你理解的模型低置信度样本,应该相当于我去看了另一个孩子,但无法分辨男女,这时候是一个困难样本;而我理解的低置信度样本,是模型“短路”,相当于我去看了另一个孩子,但刚好那时候我视力有问题,相当于没看成。怎么说呢,可能二者有之吧,因为根据很多实验结果,模型认为的困难样本,其实也并不是真的困难样本,所以模型认为它低置信度,有时候是真的困难,有时候真的可能只是随机故障了~
你用大数定律来解释也可以,大数定律是指整体结果会趋于平均状态,即A、B、C、D均匀分布;而我的观点是A、B、C、D均匀分布是最大概率状态,所以在前面有很多A、B、C时,最后一题选D属于最大概率选择。
跟苏神讨论问题十分有意思,能够引人理清思路深入考虑问题。
其实我一直疑问的点在于,一般化的样本的独立性假设怎么会在CAN中不成立?
CAN到底采用了什么原理引入了这种相关性?
显然,大数定律在样本独立的情况下也是成立的,所以我在楼上就采用了这个解释。
不过通过与苏神的讨论,现在我可以得出一个更加“模型化”的解释了。
传统频率派的条件下,CAN事实上是不成立的:
模型的数据生成过程:给定分布参数 θ0,样本 xi 独立同分布于 p(xi|θ0)。
xi 与 xj 是没什么关系的,也就不存在用 xi 的分布去修正 xj 的分布这种说法了。
然而,Bayes 学派却可以提供一种满足要求的条件,这也是为什么苏神会提到“先验”的原因吧(虽然没有显性提 Bayes)。
Bayes 派的数据生成过程:给定先验分布参数 α,数据分布参数变量 θ|α∼p(θ|α),在分类的情况下是狄利克雷分布 θ|α∼Dir(α),然后样本 xi 在给定 θ 的条件下独立同分布于 xi|θ∼p(x|θ)。
注意,此处的 θ 为变量,且是隐变量,这意味着 xi 与 xj 是相关的!
先验参数 α 就可以通过训练集进行估计了,EM 算法也好、对 θ 积分后直接得到 p(x|α) 进行最大似然估计也好。
现在表达相关性的模型有了,那么修正的过程是怎样的呢?
实际上这是一个引入平均场估计的想法!
类似于变分推断中的平均场,只不过这里是直接经验公式修正罢了。
对于待估计样本集合 {xi}Ni=1,其中 {xi}N−1i=1 为高置信度样本,分布为 q(xi)(i=1,⋯,N−1),另一个为低置信度样本 xN∼q(xN),CAN希望 p(x1,x2,⋯xN−1,xN|α)≈∏N−1i=1q(xi)q(xN),其中 q(xi)(i=1,⋯,N−1) 固定,而修正 q(xN) 使得等式近似。
整个思路就是这样了,Bayes YYDS!
在分类的情况下,样本分布可以是 xi|θ∼Cat(θ)
其实我们的理解越来越接近了,就是对如何生成各题答案的流程的理解。
在先验分布约束下,我们认为答案是这样生成的:简单起见假设只有4道题,每道题的答案为A、B、C、D之一,那么出题者先从均匀分布中选4个答案,然后再将答案打乱分配给4道题,这种情况下,组合{A,B,C,D}的概率是最大的,所以前3道题如果选了A、B、C,那么第4道选D的概率是最大的。
也就是说,先验分布是促进这一切的“幕后黑手”。
用贝叶斯那一套的话来说的话,其实你使用的近似解是平均场(Mean Field),宏观体现就是大数定律;我潜意识里用的解是最大后验估计(MAP),所以我一直说最大概率的情况。而在均匀分布下,这两个恰好相等~
哈哈,很高兴能和苏神一起把这件事情整理通透~
October 28th, 2021
使用这种方式的前提,首先得保证数据集是完全随机的,如果样本顺序不随机,先验概率分布就不置信了吧?小白,请教一下
假设的就是训练集和验证集同分布。
November 1st, 2021
多谢对这篇文章的介绍和解读。尤其是那个做题的例子 真的特别恰当!
我们是从pragmatic的工作 (Rational Speech Act)中获得的灵感,并且沿用了RSA的hyperparameters。所以和你的出发点有点不同。所以可能让你觉得”故弄玄虚“。其实在写文章的时候也是前后修改了好多次,纠结于要不要说清楚RSA和他的关系。最后决定把和RSA的关系放到讨论的部分。
另外对于不确定性指标,我们的确是把这个值压缩到0到1之间,主要考虑是为了在不同数据集上做实验是能统一标准和对比,毕竟如果不normalize的话,这个指标和number of classes有关。
欢迎作者大驾光临~
1、“故弄玄虚”确实是理解之后的第一感觉,但没有对作者不敬之意,抱歉;
2、@李子涵|comment-17645同学的讨论,也值得参考,欢迎阅读;
3、对于不确定性指标,我指的是你们设计的指标(如果我没理解错的话),通常会过度压缩(明显小于1)。
November 5th, 2021
苏神,我想问一下这个用高置信度的去纠正低置信度的这里,高置信度的样本和低置信度的样本是指一个batch里的吗,还是指整个测试集的里所有的样本啊
比较实用的方法是:高置信度样本使用验证集的全体高置信度样本,然后固定下来不变;每次预测的时候,先判断预测结果的置信度,如果属于高置信度,直接返回,否则按照公式修正。
November 5th, 2021
感谢大佬分享,花了一天时间调通到我的代码里,不过可能是我改的有问题。 分掉的很厉害。