前两天刷到了Google的一篇论文《Step-size Adaptation Using Exponentiated Gradient Updates》,在其中学到了一些新的概念,所以在此记录分享一下。主要的内容有两个,一是非负优化的指数梯度下降,二是基于元学习思想的学习率调整算法,两者都颇有意思,有兴趣的读者也可以了解一下。

指数梯度下降 #

梯度下降大家可能听说得多了,指的是对于无约束函数L(θ)的最小化,我们用如下格式进行更新:
θt+1=θtηθL(θt)
其中η是学习率。然而很多任务并非总是无约束的,对于最简单的非负约束,我们可以改为如下格式更新:
θt+1=θtexp(ηθL(θt))
这里的是逐位对应相乘(Hadamard积)。容易看到,只要初始化的θ0是非负的,那么在整个更新过程中θt都会保持非负,这就是用于非负约束优化的“指数梯度下降”。

怎么理解这个“指数梯度下降”呢?也不难,转化为无约束的情形进行推导就行了。如果θ是非负的,那么φ=logθ就是可正可负的了,因此可以设θ=eφ转化为关于φ的无约束优化问题,继而就可以用梯度下降解决:
φt+1=φtηφL(eφt)=φtηeφteφL(eφt)
我们认为梯度的eφt这部分只起到了调节学习率的作用,所以它不是本质重要的,我们将它舍去得到
φt+1=φtηeφL(eφt)
两边取指数得
eφt+1=eφtexp(ηeφL(eφt))
换回θ=eφ就得到式(2)

元学习调学习率 #

对于元学习(Meta Learning),可能多数读者都跟笔者一样听得多,但几乎没接触过。简单来说,普通机器学习跟元学习的关系,就像是数学中“函数”跟“泛函”的关系,泛函是“函数的函数”,元学习则是“学习如何学习(Learning How to Learn)”,也就是说它是关于“学习”本身的方法论,比如接下来要介绍的,就是“用梯度下降去调整梯度下降”。

我们从一般的梯度下降出发,记目标函数L的梯度为g,那么更新公式为
θt+1=θtηgt
我们希望给每个分量都调节一下学习率,所以我们引入跟参数一样大小的非负变量ν,修改更新公式为
θt+1=θtηνt+1gt
那么,ν要按照什么规则迭代呢?记住我们最终的目的是最小化L,所以ν的更新规则应该也要是梯度下降,而这里ν要求是非负的,所以我们用指数梯度下降:
νt+1=νtexp(γνtL)
注意L本来只是θ的函数,但根据(7),在t时刻我们有θt=θt1ηνtgt1,所以根据链式法则有
νtL=ηgt1θtL=ηgt1gt
代入到ν的更新公式(8),得到
νt+1=νtexp(γηgt1gt)
γη合成一个参数γ,于是整个模型的更新公式是:
νt+1=νtexp(γgt1gt)θt+1=θtηνt+1gt
如果ν初始化为全1,那么将有
νt+1=exp(γtk=1gk1gk)
可以看到,该方法的学习率调节思路是:如果某分量相邻两步的梯度经常同号,那么对应项的累加结果就是正的,意味着我们可以适当扩大一下学习率;如果相邻两步的梯度经常异号,那么对应项的累加结果很可能是负的,意味着我们可以适当缩小一下学习率。

注意这跟Adam调学习率的思想是不一样的,Adam调节学习率的思想是如果某个分量的梯度长时间很小,那么就意味着该参数可能没学好,所以尝试放大它的学习率。两者也算是各有各的道理吧。

简单做个小结 #

本文主要对“指数梯度下降”和“元学习调学习率”两个概念做了简单笔记,“指数梯度下降”是非负约束优化的一个简单有效的方案,而“元学习调学习率”则是元学习的一个简单易懂的应用。其中在介绍“元学习调学习率”时笔者做了一些简化,相比原论文的形式更为简单一些,但思想是一致的。

转载到请包括本文地址:https://spaces.ac.cn/archives/8968

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Mar. 03, 2022). 《指数梯度下降 + 元学习 = 自适应学习率 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/8968

@online{kexuefm-8968,
        title={指数梯度下降 + 元学习 = 自适应学习率},
        author={苏剑林},
        year={2022},
        month={Mar},
        url={\url{https://spaces.ac.cn/archives/8968}},
}