1 導引
我們在《Python中的隨機采樣和概率分布(二)》介紹了如何用Python現有的庫對一個概率分布進行采樣,其中的Dirichlet分布大家一定不會感到陌生,這篇博客我們來更詳細地介紹Dirichlet分布的性質及其在聯邦學習領域的應用。
2 Dirichlet分布及其性質
Dirichlet分布[1]是定義在\(\mathbb{R}^N\)上的概率密度。Dirichlet分布以度量\(\bm{u}\)(所有系數\(\bm{u}_i>0\)的一個向量)為參數,可將它寫為\(\bm{u}=\alpha \bm{m}\),這里\(\bm{m}\)是在\(N\)個分量上的歸一化度量(\(\sum_{i=1}^N m_i = 1\), \(m_i > 0\)),且\(\alpha\)是一個正數。Dirichlet分布的概率密度函數由下式給出:
注意,對\(\bm{X} = (X_1, \cdots, X_N)\sim \text{Dir}(\alpha \bm{m})\),有\(X_i>0 , \sum_{i=1}^N X_i = 1\)。
向量\(\bm{m}\)是隨機向量\(\bm{X}\)的期望:
在物理意義上,Dirichlet分布中參數\(\alpha\)的作用主要體現在兩個方面。首先,\(\alpha\)度量了這個分布的銳度(sharpness),也即測量我們分布中的典型樣本\(\bm{X}\)與其均值\(\bm{m}\)相差多遠,就像高斯分布中精度\(\tau=1/\sigma^2\)度量了樣本與它的均值偏差多遠一樣。一個大的\(\alpha\)值會使得\(\bm{X}\)的分布在\(\bm{m}\)附近急劇出現尖峰(后文我們會提到,在聯邦數據划分中,這將導致不同標簽在客戶端的分布更為同構)。下圖就體現了\(\alpha\)對\(\bm{X}\)分布的影響:
注意我們這里是從滿足\(N=3\)的分布中采樣1000個3維樣本點,兩個軸表示\(X_1\)和\(X_2\),\(X_3\)在可視化中並不使用。
這里附上可視化的代碼,感興趣的童鞋可下來自行嘗試:
import numpy as np
import matplotlib.pyplot as plt
us = [(0.1, 0.1, 0.1), (1, 1, 1), (10, 10, 10)] # 3組不同的u=alpha*m參數
points = [[] for i in range(3)]
for i in range(3):
points[i] = np.random.dirichlet(us[i], size=100)
xs, ys = [[] for i in range(3)], [[] for i in range(3)]
for i in range(3):
xs[i], ys[i], _ = list(zip(*points[i]))
fig, axs = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
for i in range(3):
axs[i].set_title(f"$αm={us[i]}$")
axs[i].scatter(xs[i], ys[i])
axs[i].set_xlabel("$X_1$")
axs[i].set_ylabel("$X_2$")
plt.suptitle(r"The Display of $X_1, X_2$ in $(X_1, X_2, X_3)$")
plt.show()
3 Dirichlet分布在聯邦學習中的應用
3.1 划分不獨立同分布(Non-IID)數據集
我們在聯邦學習中,經常會假設不同client間的數據集不滿足獨立同分布(Non-IID)。那么我們如何將一個現有的數據集按照Non-IID划分呢?我們知道帶標簽樣本的生成分布看可以表示為\(p(\bm{x}, y)\),我們進一步將其寫作\(p(\bm{x}, y)=p(\bm{x}|y)p(y)\)。其中如果要估計\(p(\bm{x}|y)\)的計算開銷非常大,但估計\(p(y)\)的計算開銷就很小。所有我們按照樣本的標簽分布來對樣本進行Non-IID划分是一個非常高效、簡便的做法。
總而言之,我們采取的算法思路是盡量讓每個client上的樣本標簽分布不同。我們設有\(K\)個類別標簽,\(N\)個client,每個類別標簽的樣本需要按照不同的比例划分在不同的client上。我們設矩陣\(\bm{X}\in \mathbb{R}^{K*N}\)為類別標簽分布矩陣,其行向量\(\bm{x}_k\in \mathbb{R}^N\)表示類別\(k\)在不同client上的概率分布向量(每一維表示\(k\)類別的樣本划分到不同client上的比例),該隨機向量就采樣自Dirichlet分布(第一次采用Dirichlet分布來划分數據集的論文為《Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification》[2])。
據此,我們可以寫出以下的划分算法:
def dirichlet_split_noniid(train_labels, alpha, n_clients):
'''
按照參數為alpha的Dirichlet分布將樣本索引集合划分為n_clients個子集
'''
n_classes = train_labels.max()+1
# (K, N) 類別標簽分布矩陣X,記錄每個類別划分到每個client去的比例
label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
# (K, ...) 記錄K個類別對應的樣本索引集合
class_idcs = [np.argwhere(train_labels == y).flatten()
for y in range(n_classes)]
# 記錄N個client分別對應的樣本索引集合
client_idcs = [[] for _ in range(n_clients)]
for k_idcs, fracs in zip(class_idcs, label_distribution):
# np.split按照比例fracs將類別為k的樣本索引k_idcs划分為了N個子集
# i表示第i個client,idcs表示其對應的樣本索引集合idcs
for i, idcs in enumerate(np.split(k_idcs,
(np.cumsum(fracs)[:-1]*len(k_idcs)).
astype(int))):
client_idcs[i] += [idcs]
client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
return client_idcs
其中np.random.dirichlet
函數的具體用法大家可以參見我的上一篇博客《Python中的隨機采樣和概率分布(二)》和numpy文檔《numpy.random.dirichlet函數》[3],此處不再贅述。
3.2 算法測試與可視化呈現
接下來我們在EMNIST數據集上調用該函數進行測試,並進行可視化呈現。我們設client數量\(N=10\),Dirichlet概率分布的參數\(\alpha=1.0\)(也是我們聯邦學習常用的設置),\(\bm{m}\in \mathbb{R}^N\)在我們這里表示每個client上各類型標簽數量的先驗分布,我們規定是均勻分布\(\bm{m}= (1, 1, \cdots, 1)\)(注意,因為有\(\alpha\)這個縮放因子在,所以是否真的歸一化了無所謂的,只要\(\bm{m}\)每個維度相等,那就可以說明每個client上各類型標簽數量的先驗分布是均勻分布)。數據集划分的可視化呈現如下:
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import ConcatDataset
n_clients = 10
dirichlet_alpha = 1.0
seed = 42
if __name__ == "__main__":
np.random.seed(seed)
train_data = datasets.EMNIST(
root=".", split="byclass", download=True, train=True)
test_data = datasets.EMNIST(
root=".", split="byclass", download=True, train=False)
classes = train_data.classes
n_classes = len(classes)
labels = np.concatenate(
[np.array(train_data.targets), np.array(test_data.targets)], axis=0)
dataset = ConcatDataset([train_data, test_data])
# 我們讓每個client不同label的樣本數量不同,以此做到Non-IID划分
client_idcs = dirichlet_split_noniid(
labels, alpha=dirichlet_alpha, n_clients=n_clients)
# 展示不同label划分到不同client的情況
plt.figure(figsize=(12, 8))
plt.hist([labels[idc]for idc in client_idcs], stacked=True,
bins=np.arange(min(labels)-0.5, max(labels) + 1.5, 1),
label=["Client {}".format(i) for i in range(n_clients)],
rwidth=0.5)
plt.xticks(np.arange(n_classes), train_data.classes)
plt.xlabel("Label type")
plt.ylabel("Number of samples")
plt.legend(loc="upper right")
plt.title("Display Label Distribution on Different Clients")
plt.show()
最終的可視化結果如下:
可以看到,62個類別標簽在不同client上的分布確實不同,證明我們的樣本划分算法是有效的。
我們嘗試將\(\alpha\)設置為\(0.1\),可以看到標簽分布的異構程度確實有所加大(結合我們前面所講的Dirichlet分布性質,也就是表示標簽概率分布的樣本點變得分散):
最后,如果我們想將\(x\)軸變為client,\(y\)軸變為標簽類別,即更明確地可視化不同client上的標簽分布情況,我們可以將數據可視化部分的代碼修改如下:
# 展示不同client上的label分布
plt.figure(figsize=(12, 8))
label_distribution = [[] for _ in range(n_classes)]
for c_id, idc in enumerate(client_idcs):
for idx in idc:
label_distribution[labels[idx]].append(c_id)
plt.hist(label_distribution, stacked=True,
bins=np.arange(-0.5, n_clients + 1.5, 1),
label=classes, rwidth=0.5)
plt.xticks(np.arange(n_clients), ["Client %d" %
c_id for c_id in range(n_clients)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend()
plt.title("Display Label Distribution on Different Clients")
plt.show()
此時我們可以看到不同client上的標簽分布情況如下圖所示:
這里有個很尷尬問題:類別數量太多,導致圖右邊的圖例放不下了。因此建議如果采用這種可視化方法的話最好選擇類別數量少的數據集,比如CIFAR10。
參考
- [1] MacKay D J C, Mac Kay D J C. Information theory, inference and learning algorithms[M]. Cambridge university press, 2003.(chapter 23)
- [2] Hsu T M H, Qi H, Brown M. Measuring the effects of non-identical data distribution for federated visual classification[J]. arXiv preprint arXiv:1909.06335, 2019.
- [3] 《numpy.random.dirichlet函數》