1 混合分布(Mixture Distribution)划分算法
我們在博文《聯邦學習:按病態獨立同分布划分Non-IID樣本》中學習了聯邦學習開山論文[1]中按照病態獨立同分布(Pathological Non-IID)划分樣本。 在上一篇博文《聯邦學習:按Dirichlet分布划分Non-IID樣本》中我們也已經提到了按照Dirichlet分布划分聯邦學習Non-IID數據集的一種算法。下面讓我們來看按Dirichlet分布划分數據集的另外一種變種,即按混合分布划分Non-IID樣本,該方法為論文[2]中首次提出。
該論文采取了一個重要的假設,那就是雖然聯邦學習每個client的數據是Non-IID的,但我們假設每個client的數據都來自於某個混合分布(混合成分個數\(K\)為超參數可調)。
其中\(t\)意思為第\(t\)個client,\(z_{tk}\)為(不可觀測的)隱變量(latent variable),意為第\(t\)個client中的數據來自成分\(k\)的概率。第\(t\)個client的某個樣本點\(x\)進行生成時,會從\(K\)個成分中選擇一個成分\(p(x|\theta_{k})\)進行采樣,選擇該成分的概率為\(\alpha_{tk}\)。
形象化的展示圖片如下:
有了這個假設, 那么每個client的數據都可以視為來自這三個分布的數據的混合(每個client的Non-IID區別只是混合比例系數各不相同而已,下面我們提到混合比例系數由Dirichlet分布隨機生成),那我們相當於假定了每個client數據間的一種"相似性",即在各節點數據表面的Non-IID(\(p(x|\theta_t)\))中其實潛藏IID的成分(\(p(x|\theta_{k}),k=1,2,..K\))。經過我的實驗,一旦這樣划分數據,那么對於基准的個性化聯邦學習算法都會提升精度, 但是[2]作者提出了一種基於子模型集成的算法來更加充分地利用這種相似性。比如,假設一個client一共有A、B、C這3個子成分, 那么我們就設計三個子模型分別對這些成分進行學習,每個模型的參數可以作為成分數據分布參數的一種體現。對於隱變量\(z_{tk}\)(做為子模型加權使用),作者設計了EM算法來進行推斷。
注意,這里作者的思想讓我們聯想到高斯混合分布。高斯混合分布就假設每個節點的數據采樣自高斯混合分布中的一個成分(對應一個聚類簇),而經典的高斯混合聚類就是要確定每個節點和簇的的對應關系(並推斷出隱變量系數), 可以參見我的博客《統計學習:EM算法及其在高斯混合模型(GMM)中的應用》。
接下來我們來看這個划分算法的函數如何設計。除了常規Dirichlet划分算法所要求的n_clients、n_classes、alpha等, 它還有一個專門的n_clusters參數,表示混合成分個數。我們來看函數原型:
def mixture_distribution_split_noniid(dataset, n_classes, n_clients, n_clusters, alpha, seed):
我們解釋一下函數的參數,這里dataset是torch.utils.Dataset類型的數據集,n_classes表示數據集里樣本分類數,n_clusters是簇的個數(后面會解釋其含義,如果設置為-1,則就默認n_clusters=n_classes,即每個簇對應一個標簽類別),alpha 為Dirichlet分布參數,用於控制clients之間的數據diversity(Non-IID多樣性)。seed為自定義的隨機數種子。該函數返回一個由n_client個client所需的樣本索引組成的列表組成的列表client_idcs。
接下來我們看這個函數的內容。這個函數的內容可以概括為:先將所有類別不重疊地划分為n_clusters個簇(每個簇對應一個不同的標簽分布,體現為標簽不重疊);再對每個簇c,將樣本按照Non-IID划分給不同的clients(每個client的樣本數量按照dirichlet分布來確定)。
首先,我們判斷n_clusters的數量,如果為-1,則默認每一個cluster對應一個數據class:
if n_clusters == -1:
n_clusters = n_classes
然后將打亂后的標簽集合\(\{0,1,...,n\_classes-1\}\)分為n_clusters個簇。注意,這就意為着每個簇對應的標簽集合沒有重疊,也就是說各個簇之間的樣本數據是Non-IID的。
all_labels = list(range(n_classes))
rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
rng = random.Random(rng_seed)
np.random.shuffle(all_labels)
def avg_divide(l, g):
"""
將列表`l`分為`g`個獨立同分布的group(其實就是直接划分)
每個group都有 `int(len(l)/g)` 或者 `int(len(l)/g)+1` 個元素
返回由不同的groups組成的列表
"""
num_elems = len(l)
group_size = int(len(l) / g)
num_big_groups = num_elems - g * group_size
num_small_groups = g - num_big_groups
glist = []
for i in range(num_small_groups):
glist.append(l[group_size * i: group_size * (i + 1)])
bi = group_size * num_small_groups
group_size += 1
for i in range(num_big_groups):
glist.append(l[bi + group_size * i:bi + group_size * (i + 1)])
return glist
clusters_labels = avg_divide(all_labels, n_clusters)
然后再根據上面划分好的label集合建立key為label, value為簇id(group_idx)的字典,
label2cluster = dict() # maps label to its cluster
for group_idx, labels in enumerate(clusters_labels):
for label in labels:
label2cluster[label] = group_idx
接着獲取數據集的索引
data_idcs = list(range(len(dataset)))
之后,我們將根據樣本的label和前面建立的label->cluster映射,再將樣本划分到對應簇里。
# 記錄每個cluster大小的向量
clusters_sizes = np.zeros(n_clusters, dtype=int)
# 存儲每個cluster對應的數據索引
clusters = {k: [] for k in range(n_clusters)}
for idx in data_idcs:
_, label = dataset[idx]
# 由樣本數據的label先找到其cluster的id
group_id = label2cluster[label]
# 再將對應cluster的大小+1
clusters_sizes[group_id] += 1
# 將樣本索引加入其cluster對應的列表中
clusters[group_id].append(idx)
# 將每個cluster對應的樣本索引列表打亂
for _, cluster in clusters.items():
rng.shuffle(cluster)
我們已經得到了屬於每個cluster的樣本索引,接着我們按照Dirichlet分布再將每個cluster中的樣本Non-IID地划分到各client上去。
# 記錄某個cluster的樣本分到某個client上的數量
clients_counts = np.zeros((n_clusters, n_clients), dtype=np.int64)
# 遍歷每一個cluster
for cluster_id in range(n_clusters):
# 對每個client賦予一個滿足dirichlet分布的權重,用於該cluster樣本的分配
weights = np.random.dirichlet(alpha=alpha * np.ones(n_clients))
# np.random.multinomial 表示投擲骰子clusters_sizes[cluster_id](該cluster中的樣本數)次,落在各client上的權重依次是weights
# 該函數返回落在各client上各多少次,也就對應着各client應該分得來自該cluster的樣本數
clients_counts[cluster_id] = np.random.multinomial(clusters_sizes[cluster_id], weights)
# 對每一個cluster上的每一個client的計數次數進行前綴(累加)求和,
# 相當於最終返回的是每一個cluster中按照client進行划分的樣本分界點下標
clients_counts = np.cumsum(clients_counts, axis=1)
然后,我們根據上面已經得到的屬於各cluster的樣本集合,和各cluster中樣本分到各client中的情況(我們已經得到了每一個cluster中按照client進行划分的樣本分界點下標),合並歸納得到每一個client中分得的樣本情況。
def split_list_by_idcs(l, idcs):
"""
將列表`l` 划分為長度為 `len(idcs)` 的子列表
第`i`個子列表從下標 `idcs[i]` 到下標`idcs[i+1]`
(從下標0到下標`idcs[0]`的子列表另算)
返回一個由多個子列表組成的列表
"""
res = []
current_index = 0
for index in idcs:
res.append(l[current_index: index])
current_index = index
return res
clients_idcs = [[] for _ in range(n_clients)]
for cluster_id in range(n_clusters):
# cluster_split為一個cluster中按照client划分好的樣本
cluster_split = split_list_by_idcs(clusters[cluster_id], clients_counts[cluster_id])
# 將每一個client的樣本累加上去
for client_id, idcs in enumerate(cluster_split):
clients_idcs[client_id] += idcs
最后,我們返回每個client對應的樣本索引:
return clients_idcs
2 算法測試與可視化呈現
接下來我們在EMNIST數據集上調用該函數進行測試,並進行可視化呈現。我們設client數量\(N=10\),混合成分個數為3,Dirichlet概率分布的參數向量\(\bm{\alpha}\)滿足\(\alpha_i=0.4,\space i=1,2,...N\):
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import ConcatDataset
n_clients = 10
n_components = 3
dirichlet_alpha = 1.0
seed = 42
if __name__ == "__main__":
random.seed(seed)
np.random.seed(seed)
train_data = datasets.CIFAR10(
root=".", download=True, train=True)
test_data = datasets.CIFAR10(
root=".", download=True, train=False)
classes = train_data.classes
n_classes = len(classes)
labels = np.concatenate(
[np.array(train_data.targets), np.array(test_data.targets)], axis=0)
dataset = ConcatDataset([train_data, test_data])
client_idcs = mixture_distribution_split_noniid(
train_data, n_classes, n_clients, n_components, dirichlet_alpha, seed)
plt.figure(figsize=(12, 8))
label_distribution = [[] for _ in range(n_classes)]
for c_id, idc in enumerate(client_idcs):
for idx in idc:
label_distribution[labels[idx]].append(c_id)
plt.hist(label_distribution, stacked=True,
bins=np.arange(-0.5, n_clients + 1.5, 1),
label=classes, rwidth=0.5)
plt.xticks(np.arange(n_clients), ["Client %d" %
c_id for c_id in range(n_clients)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend(loc="upper right")
plt.title("Display Label Distribution on Different Clients")
plt.show()
最終的可視化結果如下:
可以看到,62個類別標簽在不同client上的分布雖然不同,但相對下面的完全基於Dirichlet的樣本划分算法(\(\alpha=1.0\)),每個client之間的標簽類別分布顯得"更加相似",即看得出來都來自於一個混合分布,這證明我們的混合分布樣本划分算法是有效的。
最后附上完整代碼:
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import ConcatDataset
n_clients = 10
n_components = 3
dirichlet_alpha = 1.0
seed = 42
def mixture_distribution_split_noniid(dataset, n_classes, n_clients, n_clusters, alpha, seed):
if n_clusters == -1:
n_clusters = n_classes
all_labels = list(range(n_classes))
rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
rng = random.Random(rng_seed)
np.random.shuffle(all_labels)
def avg_divide(l, g):
num_elems = len(l)
group_size = int(len(l) / g)
num_big_groups = num_elems - g * group_size
num_small_groups = g - num_big_groups
glist = []
for i in range(num_small_groups):
glist.append(l[group_size * i: group_size * (i + 1)])
bi = group_size * num_small_groups
group_size += 1
for i in range(num_big_groups):
glist.append(l[bi + group_size * i:bi + group_size * (i + 1)])
return glist
clusters_labels = avg_divide(all_labels, n_clusters)
label2cluster = dict()
for group_idx, labels in enumerate(clusters_labels):
for label in labels:
label2cluster[label] = group_idx
data_idcs = list(range(len(dataset)))
clusters_sizes = np.zeros(n_clusters, dtype=int)
clusters = {k: [] for k in range(n_clusters)}
for idx in data_idcs:
_, label = dataset[idx]
group_id = label2cluster[label]
clusters_sizes[group_id] += 1
clusters[group_id].append(idx)
for _, cluster in clusters.items():
rng.shuffle(cluster)
clients_counts = np.zeros((n_clusters, n_clients), dtype=np.int64)
for cluster_id in range(n_clusters):
weights = np.random.dirichlet(alpha=alpha * np.ones(n_clients))
clients_counts[cluster_id] = np.random.multinomial(clusters_sizes[cluster_id], weights)
clients_counts = np.cumsum(clients_counts, axis=1)
def split_list_by_idcs(l, idcs):
res = []
current_index = 0
for index in idcs:
res.append(l[current_index: index])
current_index = index
return res
clients_idcs = [[] for _ in range(n_clients)]
for cluster_id in range(n_clusters):
cluster_split = split_list_by_idcs(clusters[cluster_id], clients_counts[cluster_id])
for client_id, idcs in enumerate(cluster_split):
clients_idcs[client_id] += idcs
return clients_idcs
if __name__ == "__main__":
random.seed(seed)
np.random.seed(seed)
train_data = datasets.CIFAR10(
root=".", download=True, train=True)
test_data = datasets.CIFAR10(
root=".", download=True, train=False)
classes = train_data.classes
n_classes = len(classes)
labels = np.concatenate(
[np.array(train_data.targets), np.array(test_data.targets)], axis=0)
dataset = ConcatDataset([train_data, test_data])
client_idcs = mixture_distribution_split_noniid(
train_data, n_classes, n_clients, n_components, dirichlet_alpha, seed)
plt.figure(figsize=(12, 8))
label_distribution = [[] for _ in range(n_classes)]
for c_id, idc in enumerate(client_idcs):
for idx in idc:
label_distribution[labels[idx]].append(c_id)
plt.hist(label_distribution, stacked=True,
bins=np.arange(-0.5, n_clients + 1.5, 1),
label=classes, rwidth=0.5)
plt.xticks(np.arange(n_clients), ["Client %d" %
c_id for c_id in range(n_clients)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend(loc="upper right")
plt.title("Display Label Distribution on Different Clients")
plt.show()
參考
- [1] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.
- [2] Marfoq O, Neglia G, Bellet A, et al. Federated multi-task learning under a mixture of distributions[J]. Advances in Neural Information Processing Systems, 2021, 34.
