MLP_BP反向传播
实验要求
以三层感知机为例,使用反向传播算法更新MLP的权重和偏置项。
Define and as: where is the mean
vector of (the
output of the th sample from the
cth class), is the
mean vector of the output from all classes,
is the number of samples from
the cth class. Define the discriminative regularization term
and incorporate it into the objective function of the MLP:
符号说明
以经典的手写体识别任务为例,说明本次实验推导所用符号的含义:
其中 为激活函数,此处使用sigmoid函数即
导函数为
对正则项的理解
首先我们看到,这个题目中的正则项不同于常见的L1正则项或者L2正则项。那么它代表什么含义,又是怎么起到正则化的作用的呢?
在了解了线性判别分析之后,发现这个正则项与线性判别分析中所谓的”类内散度矩阵“和”类问散度矩阵"非常相似。而线性判别分析的核心思想便是”类内方差小、类间间隔大“,在这里也是如此。
从损失函数可以看出,,我们希望让类内散度矩阵 尽可能小,即同一类的样本尽量预测结果一致;同时 是类间散度矩阵,我们希望让它尽可能大,以让模型更好的”区别“开不同的样本。
为了下面分析的方便,对损失函数进行拆解。
容易证明,对于两个阶数都是 的矩阵 , 其中一个矩阵乘以
另一个矩阵的转置的迹, 本质是 两个矩阵对应位置的元素相乘并相加,
可以理解为向量的点积在矩阵上的推广, 即:
则对于题中的列向量也是如此,在损失函数中表现为各元素的平方和。
由此可以得到单个样本单个特征的损失函数:
对此公式符号的含义进行如下直观解释:
image-20221124230019479
注:上图所指 也为 ; 表示第 类上的预测值在所有样本上的平均。 ##
梯度下降求解
目标
首先我们需要清楚反向传播的目的:我们希望根据模型在样本上的表现结果调节模型,最小化损失函数以让其在训练集上的表现更好。
具体到神经网络,我们需要调节的是每一条边对应的权重或偏置,依据是损失函数对该层权重的偏导。直观一点说,偏导反映的是参数的微小变化对损失的影响。我们希望最小化损失函数,那么比如如果求出来对权重的偏导是正的,那么说明损失函数随权重增大而增大,那么就要让权重变小一点。
因此我们更新参数的方式如下: 其中 是权重和偏置所在的层数,对于三层感知机 。 为学习率。 为整个训练集大小,比如对于MNIST这个值为60000。
偏导求解
在上面定义了对于单个样本单个特征的损失函数。下面为简化叙述,采用逐样本进行偏导的求解。根据题目的含义我们应该是使用批量梯度下降法 进行更新,此时将对每个样本求得的偏导加起来求得总的 代入 中(而不是每个样本都使用 式进行更新)即可。
也就是
为类的个数,比如手写体识别中为10,数字 对应于类 。
阅读下面的求解过程时如果担心忘记符号对应的含义,可以将下图固定在屏幕上。
最后一层
从总体来看,对于最后一层的某个边 的权重更新是比较容易进行的。如图所示, , 为类别的个数。 , 为隐层包含的神经元个数。
由链式法则,有:
由于符号定义中第三层的神经元能够比较好的和第二层的区别开,故省略了表示层数的上标;而 和 的取值与均与特定的样本有关,故都保留了表示样本的下角标 。
我们仍旧可以直观的理解链式法则对应于参数更新的含义。我们要求的是损失函数 对 的敏感程度,而 能够直接影响的是 , 影响 , 才直接影响到损失函数。因此需要借助链式法则将这个”影响链“串起来。
取决于第二层所有神经元,但我们只需要关注与正在求偏导的边相关的节点:
故有 ,因此
就等于激活函数的导数,即 。重点是最后一项:
而 中包含 ,而且我们只关心 。 里面自然有一个 。不仅如此,别忘了 和 也都是与 相关的变量,比如 是这个样本所属的类在训练集中的数目,比如这个样本实际上是个"3",那么 就是训练集中"3"的数目。
自然对应 中的某一个。因此
同样的, 是训练集中样本数目。
由此我们可以继续求解:
将上面求得的结果代入对 求偏导的式子中:
最后,如果我们想要进行批量梯度下降,需要将所有训练集中样本的损失函数加起来求平均,进行一次更新:
由于三个部分均与 有关,所以没有可以提取的公因子,需要逐项累加。并且该式子与第 个样本所属类有关,因此也要根据样本情况代入相应的的 和
。
此时我们就可以代入 式(更新参数的方式)中,进行每条权重边的更新了。
对于最后一层的 对应偏置的边求解的方式和 类似。 , 和 已经求得,代入即可。
倒数第二层
倒数第二层某一条权重边记为 。
重点是 。它通过影响最后一层的所有节点去影响最终的损失函数。
不同于 , 自然是指最后一层的线性变换。
我们前面已经求得了。 。
根据递推关系其实也没有很复杂。
前面两项非常容易求得:
同样的,各个样本上的偏导累加并求平均,然后设置学习率进行梯度下降更新参数即可。
,其余部分也都已经知道了,因此代入即可按同样的方式更新 。
至此,所有的权重和偏置项都已经更新完毕。
思考
这个正则项与线性判别分析一样,思路很自然,数学表达也很严谨,确远不如L1或L2正则项用的广泛。虽然并没有基于此做过实验,但从推导的过程可以看到,相比L1或L2正则项直接对权重矩阵的范数求导,计算量明显要大的多。比如至少要预先把预测每一类对应的样本数和样本向量均值算出来,在求偏导时也要判断是属于哪一类,对性能肯定是有所损耗。