变分自编码器(四):一步到位的聚类方案
By 苏剑林 | 2018-09-17 | 347588位读者 |由于VAE中既有编码器又有解码器(生成器),同时隐变量分布又被近似编码为标准正态分布,因此VAE既是一个生成模型,又是一个特征提取器。在图像领域中,由于VAE生成的图片偏模糊,因此大家通常更关心VAE作为图像特征提取器的作用。提取特征都是为了下一步的任务准备的,而下一步的任务可能有很多,比如分类、聚类等。本文来关心“聚类”这个任务。
一般来说,用AE或者VAE做聚类都是分步来进行的,即先训练一个普通的VAE,然后得到原始数据的隐变量,接着对隐变量做一个K-Means或GMM之类的。但是这样的思路的整体感显然不够,而且聚类方法的选择也让我们纠结。本文介绍基于VAE的一个“一步到位”的聚类思路,它同时允许我们完成无监督地完成聚类和条件生成。
理论 #
一般框架 #
回顾VAE的loss(如果没印象请参考《变分自编码器(二):从贝叶斯观点出发》):
$$KL\Big(p(x,z)\Big\Vert q(x,z)\Big) = \iint p(z|x)\tilde{p}(x)\ln \frac{p(z|x)\tilde{p}(x)}{q(x|z)q(z)} dzdx\tag{1}$$
通常来说,我们会假设$q(z)$是标准正态分布,$p(z|x),q(x|z)$是条件正态分布,然后代入计算,就得到了普通的VAE的loss。
然而,也没有谁规定隐变量一定是连续变量吧?这里我们就将隐变量定为$(z, y)$,其中$z$是一个连续变量,代表编码向量;$y$是离散的变量,代表类别。直接把$(1)$中的$z$替换为$(z,y)$,就得到
$$KL\Big(p(x,z,y)\Big\Vert q(x,z,y)\Big) = \sum_y \iint p(z,y|x)\tilde{p}(x)\ln \frac{p(z,y|x)\tilde{p}(x)}{q(x|z,y)q(z,y)} dzdx\tag{2}$$
这就是用来做聚类的VAE的loss了。
分步假设 #
啥?就完事了?呃,是的,如果只考虑一般化的框架,$(2)$确实就完事了。
不过落实到实践中,$(2)$可以有很多不同的实践方案,这里介绍比较简单的一种。首先我们要明确,在$(2)$中,我们只知道$\tilde{p}(x)$(通过一批数据给出的经验分布),其他都是没有明确下来的。于是为了求解$(2)$,我们需要设定一些形式。一种选取方案为
$$p(z,y|x)=p(y|z)p(z|x),\quad q(x|z,y)=q(x|z),\quad q(z,y)=q(z|y)q(y)\tag{3}$$
代入$(2)$得到
$$KL\Big(p(x,z,y)\Big\Vert q(x,z,y)\Big) = \sum_y \iint p(y|z)p(z|x)\tilde{p}(x)\ln \frac{p(y|z)p(z|x)\tilde{p}(x)}{q(x|z)q(z|y)q(y)} dzdx\tag{4}$$
其实$(4)$式还是相当直观的,它分布描述了编码和生成过程:
1、从原始数据中采样到$x$,然后通过$p(z|x)$可以得到编码特征$z$,然后通过分类器$p(y|z)$对编码特征进行分类,从而得到类别;
2、从分布$q(y)$中选取一个类别$y$,然后从分布$q(z|y)$中选取一个随机隐变量$z$,然后通过生成器$q(x|z)$解码为原始样本。
具体模型 #
$(4)$式其实已经很具体了,我们只需要沿用以往VAE的做法:$p(z|x)$一般假设为均值为$\mu(x)$方差为$\sigma^2(x)$的正态分布,$q(x|z)$一般假设为均值为$G(z)$方差为常数的正态分布(等价于用MSE作为loss),$q(z|y)$可以假设为均值为$\mu_y$方差为1的正态分布,至于剩下的$q(y),p(y|z)$,$q(y)$可以假设为均匀分布(它就是个常数),也就是希望每个类大致均衡,而$p(y|z)$是对隐变量的分类器,随便用个softmax的网络就可以拟合了。
最后,可以形象地将$(4)$改写为
$$\mathbb{E}_{x\sim\tilde{p}(x)}\Big[-\log q(x|z) + \sum_y p(y|z) \log \frac{p(z|x)}{q(z|y)} + KL\big(p(y|z)\big\Vert q(y)\big)\Big],\quad z\sim p(z|x) \tag{5}$$
其中$z\sim p(z|x)$是重参数操作,而方括号中的三项loss,各有各的含义:
1、$-\log q(x|z)$希望重构误差越小越好,也就是$z$尽量保留完整的信息;
2、$\sum_y p(y|z) \log \frac{p(z|x)}{q(z|y)}$希望$z$能尽量对齐某个类别的“专属”的正态分布,就是这一步起到聚类的作用;
3、$KL\big(p(y|z)\big\Vert q(y)\big)$希望每个类的分布尽量均衡,不会发生两个几乎重合的情况(坍缩为一个类)。当然,有时候可能不需要这个先验要求,那就可以去掉这一项。
实验 #
实验代码自然是Keras完成的了(^_^),在mnist和fashion-mnist上做了实验,表现都还可以。实验环境:Keras 2.2 + tensorflow 1.8 + Python 2.7。
代码实现 #
代码位于:https://github.com/bojone/vae/blob/master/vae_keras_cluster.py
其实注释应该比较清楚了,而且相比普通的VAE改动不大。可能稍微有难度的是$\sum_y p(y|z) \log \frac{p(z|x)}{q(z|y)}$这个怎么实现。首先我们代入
$$\begin{aligned}p(z|x)&=\frac{1}{\prod\limits_{i=1}^d\sqrt{2\pi\sigma_i^2(x)}}\exp\left\{-\frac{1}{2}\left\Vert\frac{z - \mu(x)}{\sigma(x)}\right\Vert^2\right\}\\
q(z|y)&=\frac{1}{(2\pi)^{d/2}}\exp\left\{-\frac{1}{2}\left\Vert z - \mu_y\right\Vert^2\right\}\end{aligned}\tag{6}$$
得到
$$\log \frac{p(z|x)}{q(z|y)}=-\frac{1}{2}\sum_{i=1}^d \log \sigma_i^2(x)-\frac{1}{2}\left\Vert\frac{z - \mu(x)}{\sigma(x)}\right\Vert^2 + \frac{1}{2}\left\Vert z - \mu_y\right\Vert^2 \tag{7}$$
注意其实第二项是多余的,因为重参数操作告诉我们$z = \varepsilon\otimes \sigma(x) + \mu(x),\,\varepsilon\sim \mathcal{N}(0,1)$,所以第二项实际上只是$-\Vert \varepsilon\Vert^2/2$,跟参数无关,所以$$\log \frac{p(z|x)}{q(z|y)}\sim -\frac{1}{2}\sum_{i=1}^d \log \sigma_i^2(x) + \frac{1}{2}\left\Vert z - \mu_y\right\Vert^2 \tag{8}$$
然后因为$y$是离散的,所以事实上$\sum_y p(y|z) \log \frac{p(z|x)}{q(z|y)}$就是一个矩阵乘法(相乘然后对某个公共变量求和,就是矩阵乘法的一般形式),用K.batch_dot实现。
其他的话,读者应该清楚普通的VAE的实现过程,然后才看本文的内容和代码,不然估计是一脸懵的。
mnist #
这里是mnist的实验结果图示,包括类内样本图示和按类采样图示。最后还简单估算了一下,以每一类对应的数目最多的那个真实标签为类标签的话,最终的test准确率大约有83%,对比这篇文章《Unsupervised Deep Embedding for Clustering Analysis》的结果(最高也是84%左右),感觉应该很不错了。
聚类图示 #
按类采样 #
fashion-mnist #
这里是fashion-mnist的实验结果图示,包括类内样本图示和按类采样图示,最终的test准确率大约有58.5%。
聚类图示 #
按类采样 #
总结 #
文章简单地实现了一下基于VAE的聚类算法,算法的特点就是一步到位,结合“编码”、“聚类”和“生成”三个任务同时完成,思想是对VAE的loss的一般化。
感觉还有一定的提升空间,比如式$(4)$只是式$(2)$的一个例子,还可以考虑更加一般的情况。代码中的encoder和decoder也都没有经过仔细调优,仅仅是验证想法所用。
转载到请包括本文地址:https://spaces.ac.cn/archives/5887
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Sep. 17, 2018). 《变分自编码器(四):一步到位的聚类方案 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/5887
@online{kexuefm-5887,
title={变分自编码器(四):一步到位的聚类方案},
author={苏剑林},
year={2018},
month={Sep},
url={\url{https://spaces.ac.cn/archives/5887}},
}
March 17th, 2019
你好,看了您的代码和博客,有一个问题没有理解,还望您能解答一下。
正如您在博客中描述的编码和生成过程,由x编码z再做分类得到c,然后由c采样z并经生成器恢复x,这个过程与“Variational Deep Embedding:An Unsupervised and Generative Approach to Clustering”这篇文章中图1描述是一致的。但是看您的代码似乎是由x编码z,然后利用p(z|x)采样得到z直接送入生成器恢复x,而分类部分似乎是单独的一块并没有把分类部分整合到x的端到端网络中。请您能进一步解释下吗
本文的代码就是严格按照你说的流程来实现的,对应的loss就是我目前的写法。
“由x编码z,然后利用p(z|x)采样得到z直接送入生成器恢复x”,只是loss的一部分,而整个模型的loss事实上要看成一个整体,所以“而分类部分似乎是单独的一块并没有把分类部分整合到x的端到端网络中”这个理解是错的,它不是单独的一块,它就是整个loss的一部分。
(你可以分步理解,但它就是一个整体,而不是两个不同loss的随意拼凑。)
不知作者是否有读过上面提到的那篇文章,我看下来感觉二者的过程是一样的,只不过你的分类概率用的是一个softmax网络,他是直接算出一个最优概率,不知道我的理解对不对,就是觉得好像二者是一样的,又好像有点不一样,他用的是高斯混合,您这里没有提到高斯混合,但其实也是用到了,这里就有点模糊,还希望作者能给予指点。还有一个疑问就是对于目标函数的第二项,作用是希望z能尽量对齐某个类别的“专属”的正态分布,就是这一步起到聚类的作用,直接从那个目标函数如何直接解读处这个作用啊?
为什么一定要知道它“叫什么”呢?我不知道是不是高斯混合,我只知道在文本中我很清晰地、理论与实验结合地完成了我要做的事情。对于每个类的先验分布,我使用的是条件高斯分布,条件为均值。
VAE的早期工作很多相互之间都很相似,花心思去确认是否等价是一件很没有意义的事情,你只需要知道自己学到了什么、学懂了没有就是了。
如果你对本文还有疑惑,欢迎继续留言,但我没有兴趣去比较本文跟已有论文的相似性,因为我也没有要去争个优先权的意思,不想浪费这个时间。
后来又对比了下感觉区别还是蛮大的。我提到的那篇文章是通过x的对数似然诱导出的变分下界(证据下界),然后通过优化证据下界来实现x的对数似然最大化。这里作者是目标优化KL散度。参考PRML变分推断这一章,其中讲到优化证据下界同时可以使KL散度为零,也就是说两种方法等价。但是不同的是,PRML和这里均是假设了一个所有变量的联合概率分布,而VAE和我说的那篇文章里假设的是一个条件概率分布,因此推导出来的结果很相似。我也不知道那种假设更加合理或者该如何选择,可能如作者说的很多工作有相似之处吧。希望各位有兴趣的话可以一起探讨。
联合分布本身由条件分布相乘得到
道理是这么个道理,您这里优化的是q(x,y,z)并分解成了q(x|z)q(z|y)q(y),那篇文章中优化的是q(y,z|x)并分解成了q(y|x)q(z|x),因此两种不同的分解意味着不同的模型。
June 9th, 2019
请问博主的这篇的想法是原创的,还是有借鉴其他论文,如果有其他论文能否给个参考。
本文是自己独立构思的,至于有没有类似工作,我不是很了解。
June 17th, 2019
博主您好,读完您的博文收益很深,您在最后总结的时候,提到“式(4)只是式(2)的一个例子,还可以考虑更加一般的情况”,这个一般情况可以从哪些地方考虑呢,希望您能指点一下。
比如$q(x|z)$意味着生成$x$的时候只用到了$z$,你可以考虑一般化的$q(x|z, y)$,即同时传入$z,y$来生成$x$。
August 14th, 2019
苏神您好,我有几个问题想请教您:
1.像公式(2)推导到公式(4)是怎么来的呢,有什么资料可以参考吗?
2.通过概率这种推导后,怎么转换到实际的编成呢?也就是怎么把最后的损失转换为我们可以计算的损失函数呢?
个人总感觉理论推导到实例化有很大的隔阂,而且很多书籍也没有提供这样的转换的类似参考,希望苏神能够指点一二,或者提供一些相应的参考资料。
1、(2)到(4)是举例,不是推导;也就是说(4)是(2)的一个更具体的例子罢了;
2、你需要清晰知道模型对应的数学公式,并且需要清晰知道(常见的)数学公式用框架怎么写出来;这是一个慢慢累积的过程,没有什么现成的参考资料。
August 31st, 2019
您好苏神,很感激您的工作。
请问能否具体解释一下【具体模型】一节,为什么loss2和loss3有那样的含义呢?
不知道你想表达什么意思,什么叫做“为什么loss2和loss3有那样的含义”?
October 6th, 2019
right = 0.
for i in range(10):
_ = np.bincount(y_train_[y_train_pred == i])
right += _.max()
这段统计的代码是不是有问题,不太对啊?
有什么问题?
这个是我搞错了
@苏剑林|comment-12127
# 定义损失函数
z_mean = tf.expand_dims(z_mean, 1)
z_log_var = tf.expand_dims(z_log_var, 1)
lamb = 2.5 # 这是重构误差的权重,它的相反数就是重构方差,越大意味着方差越小。
xent_loss = 0.5 * tf.reduce_mean((x - x_recon) ** 2, 0)
kl_loss = - 0.5 * (z_log_var - tf.square(z_prior_mean))
kl_loss = tf.reduce_mean(tf.matmul(tf.expand_dims(y, 1), kl_loss), 0)
cat_loss = tf.reduce_mean(y * tf.log(y + _EPSILON), 0)
vae_loss = lamb * tf.reduce_sum(xent_loss) + tf.reduce_sum(kl_loss) + tf.reduce_sum(cat_loss)
这个没有看太懂,是如何和公式(5)对应的呢?
这似乎不是我写的吧?
至于怎么对应,就是按文中解释的对应。
这个是您原始代码里拿到的,我理解这里面cat_loss是对应的公式5中KL(p(y|z)||q(y))吧,xent_loss是−loq(x|z),kl_loss对应的是中间您推到的那部分吧
原始代码如下:
# 建立模型
vae = Model(x, [x_recon, z_prior_mean, y])
# 下面一大通都是为了定义loss
z_mean = K.expand_dims(z_mean, 1)
z_log_var = K.expand_dims(z_log_var, 1)
lamb = 2.5 # 这是重构误差的权重,它的相反数就是重构方差,越大意味着方差越小。
xent_loss = 0.5 * K.mean((x - x_recon)**2, 0)
kl_loss = - 0.5 * (z_log_var - K.square(z_prior_mean))
kl_loss = K.mean(K.batch_dot(K.expand_dims(y, 1), kl_loss), 0)
cat_loss = K.mean(y * K.log(y + K.epsilon()), 0)
vae_loss = lamb * K.sum(xent_loss) + K.sum(kl_loss) + K.sum(cat_loss)
是的。并且该解释的在文章中已经解释了,剩下的还看不懂的话,则需要先补习vae的相关内容。
# 重参数层,相当于给输入加入噪声
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
x_recon = decoder(z)
y = classfier(z)
再请教您一个问题,如上面原始代码所述,就是关于代码中x_recon是直接从decoder(z)得到的,而进行decoder的z并不由 y得到的q(z|y),为什么不写成经过q(z|y)得到的呢?
该怎么做怎么写,不是拍脑袋而来的,而是根据公式$(5)\sim (8)$决定的。
October 8th, 2019
你好,请问涉及到聚类时,是怎么对y进行约束,从而使聚类效果更好呢?
什么样的约束?
请问整个实现过程怎么使聚类效果更好呢?
模型自己学会聚类,不用人工过多干扰聚类过程。如果要说聚类的依据是什么,那应该是最大似然吧。
October 10th, 2019
请问如何在目标函数中体现聚类的目标?
November 21st, 2019
你好,我想问问文章中提到的vae_keras_cluster.py代码我运行不出来呀,总是报错,好像是acc没有定义是吗,可以看下运行的结果截图吗
请先完全对齐各个软件版本再说。
December 24th, 2019
大神,您好,读了您写的VAE系列博客,受益匪浅,想请教本文中式(5)作为式(4)的一个具体实现,是如何得到的,能否进行详细的解释说明,十分感谢您的分享。
理解VAE以及本文$(4)$式的基础上,$(5)$式就是显然成立的,不能再详细了,因为它跟$(4)$式是一模一样的,甚至连变换都没有变换过。