一、阐述了联邦学习的诞生背景:
在当前数据具有价值,并且需要被保护,数据分布为non-IID情况下,需要提出一个框架来进行行之有效的训练,这也是联邦学习诞生的原因;
二、论文的相关工作:
首先,论文阐述了联邦学习所适用的领域:
1.数据集应该具有较大隐私,所以无法上传;
2.对于有监督学习下的任务,可以很轻易地判断其标签;
随后,论文举了两个基本例子:
1.典型的图像分类:根据学习用户以往的浏览照片类型来判断能够查询哪些照片;
2.典型的语言模型:典型的词语预测系统,通过以往的记录进行分析;
这两个例子和论文中所提到的适用领域不谋而合,因为:
1.类型可以通过用户标记来定义;
2.对于不同用户的习惯,数据分布很可能不同。例如官方语言和俚语,Flickr照片和手机照片;
后续论文通过这两个问题使用两种不同的网络模型来进行测试。
图像分类——前馈深层网络;
语言模型——LSTM;
通过这两个实验来进行联邦学习隐私又是和降低大型数据通信成本方面的测试;
最后,论文也论证了联邦优化较于分布式优化的差别和实际会遇到的问题。
主要来说,联邦优化和分布式优化的区别在于数据集上:
1.联邦优化数据集为non-IID,任何一个节点的数据集分布无法代表整体;
2.数据分布不平均,有的节点数据集多,有的节点数据集少,也就是不平衡属性;
而我们实际中会遇到的问题(本篇论文讨论理想状态下,并不做过多涉及):
1.客户端数据回添加或者删除;
2.客户端有可能不发送数据或者发的数据有问题;
3.客户端可用性可能会因为数据的分布的不同形式而受到影响(美式英语语音和英式英语语音可能会混杂在一起);
三、详细公式推导过程以及算法流程:
一般性的前提条件:
假设客户结点固定,为K个,并且每个客户节点有固定的本地数据集。每一轮开始,选择客户的随即分数C,将服务器当前的全局参数发送给每个客户端,每个客户端基于全局状态+本地数据集进行本地计算,将更新发送给服务器,服务器更新并且用于全局状态,重复该过程。
基本的数学表示:
对于整个学习下的目标函数,应该为
其中,L代表loss函数。
这里稍微提一下,自己初次看有点懵逼,现在发现是把神经网络全部忘光了。
【补充】:
这是一个典型的神经网络的的非凸损失函数。
其中n代表样本点哥鼠,fi(w)代表根据每个样本点i算出来的损失函数,最后进行整个求均值,个人以依稀记得每轮计算整个数据集的f(w),来进行更新;
而对于 联邦学习下的目标函数,也是神经网络下的改版,其实就是区分了个结点而已。
其中,K为节点个数总共为K个。nk代表结点k中的训练集样本个数;
所以总而言之就是把每个epoch的目标函数变成每个节点的目标函数的加和。对于单个结点内的局部更新,和上面的神经网络的目标函数不变。
具体算法的步骤:
总体还是用的神经网络那一套,使用每个epoch所得到的w矩阵参数来更新下一个epoch。
只不过设计一个全局w的问题;
如上图所示;
其实思路很简单:
1.通过在每一个批次中选取部分节点,进行一个epoch的训练,之后每个结点上传服务器。
2.服务器将所有的w,进行加和求均值得到新的w,在下发给每个结点。
3.每个结点将下发的结点代替上一个epoch算出的w,进行新的epoch的训练。
重复上述三步直到服务器确定w收敛为止。