1 病態不獨立同分布(Non-IID)划分算法
在博客《分布式機器學習、聯邦學習、多智能體的區別和聯系》中我們提到論文[1]聯邦學習每個client具有數據不獨立同分布(Non-IID)的性質。
聯邦學習的論文多是用FEMNIST、CIFAR10、Shakespare、Synthetic等數據集對模型進行測試,這些數據集包括CV、NLP、普通分類/回歸這三種不同的任務。在單次實驗時,我們對原始數據集進行非獨立同分布(Non-IID) 的隨機采樣,為\(T\)個不同非任務生成\(T\)個不同分布的數據集。
我們在博客《聯邦學習:按Dirichlet分布划分Non-IID樣本》中已經介紹了按照Dirichlet分布划分non-IID樣本。
然而聯邦學習最開始采用的數據划分方法卻不是這種。這里我們重新回顧聯邦學習開山論文[1],它所采用的的是一種病態非獨立同分布(Pathological Non-IID)划分算法。以下我們以CIFAR10數據集的生成為例,來詳細地對該論文的數據集划分與采樣算法進行分析。
首先,如果選擇這種划分方式,需要指定則每個client上數據集所需要的標簽類型數做為超參, 該划分算法的函數原型一般如下:
def pathological_non_iid_split(dataset, n_classes, n_clients, n_classes_per_client):
我們解釋一下函數的參數,這里dataset
是torch.utils.Dataset
類型的數據集,n_classes
表示數據集里樣本分類數,n_client
表示client節點的數量,該函數返回一個由n_client
各自所需樣本索引組成的列表client_idcs
。
接下來我們看這個函數的內容。該函數完成的功能可以概括為:先將樣本按照標簽進行排序;再將樣本划分為n_client * n_classes_per_client
個shards(每個shard大小相等),對n_clients
中的每一個client分配n_classes_per_client
個shards(分配到client后,每個client中的shards要合並)。
首先,從數據集索引data_idcs
建立一個key為類別\(\{0,1,...,n\_classes-1\}\),value為對應樣本集索引列表的字典,這在實際上這就相當於按照label對樣本進行排序了。
data_idcs = list(range(len(dataset)))
label2index = {k: [] for k in range(n_classes)}
for idx in data_idcs:
_, label = dataset[idx]
label2index[label].append(idx)
sorted_idcs = []
for label in label2index:
sorted_idcs += label2index[label]
然后該函數將數據分為n_clients * n_classes_per_client
個獨立同分布的shards,每個shards大小相等。然后給n_clients
中的每一個client分配n_classes_per_client
個shards(分配到client后,每個client中的shards要合並),代碼如下:
def iid_divide(l, g):
"""
將列表`l`分為`g`個獨立同分布的group(其實就是直接划分)
每個group都有 `int(len(l)/g)` 或者 `int(len(l)/g)+1` 個元素
返回由不同的groups組成的列表
"""
num_elems = len(l)
group_size = int(len(l) / g)
num_big_groups = num_elems - g * group_size
num_small_groups = g - num_big_groups
glist = []
for i in range(num_small_groups):
glist.append(l[group_size * i: group_size * (i + 1)])
bi = group_size * num_small_groups
group_size += 1
for i in range(num_big_groups):
glist.append(l[bi + group_size * i:bi + group_size * (i + 1)])
return glist
n_shards = n_clients * n_classes_per_client
# 一共分成n_shards個獨立同分布的shards
shards = iid_divide(sorted_idcs, n_shards)
np.random.shuffle(shards)
# 然后再將n_shards拆分為n_client份
tasks_shards = iid_divide(shards, n_clients)
clients_idcs = [[] for _ in range(n_clients)]
for client_id in range(n_clients):
for shard in tasks_shards[client_id]:
# 這里shard是一個shard的數據索引(一個列表)
# += shard 實質上是在列表里並入列表
clients_idcs[client_id] += shard
最后,返回clients_idcs
return clients_idcs
2 算法測試與可視化呈現
接下來我們在EMNIST數據集上調用該函數進行測試,並進行可視化呈現。我們設client數量\(N=10\),每個client規定有兩種標簽類型樣本。
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import ConcatDataset
n_clients = 10
n_classes_per_client = 2 # 每個client規定有兩種標簽類型
seed = 42
if __name__ == "__main__":
np.random.seed(seed)
train_data = datasets.CIFAR10(
root=".", download=True, train=True)
test_data = datasets.CIFAR10(
root=".", download=True, train=False)
classes = train_data.classes
n_classes = len(classes)
labels = np.concatenate(
[np.array(train_data.targets), np.array(test_data.targets)], axis=0)
dataset = ConcatDataset([train_data, test_data])
# 按照病態非獨立同分布對數據進行Non-IID划分
client_idcs = pathological_non_iid_split(
train_data, n_classes, n_clients, n_classes_per_client)
# 展示不同client的不同label的數據分布
plt.figure(figsize=(12, 8))
label_distribution = [[] for _ in range(n_classes)]
for c_id, idc in enumerate(client_idcs):
for idx in idc:
label_distribution[labels[idx]].append(c_id)
plt.hist(label_distribution, stacked=True,
bins=np.arange(-0.5, n_clients + 1.5, 1),
label=classes, rwidth=0.5)
plt.xticks(np.arange(n_clients), ["Client %d" %
c_id for c_id in range(n_clients)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend(loc="upper right")
plt.title("Display Label Distribution on Different Clients")
plt.show()
最終的可視化結果如下:
可以看到,62個類別標簽在不同client上的分布確實不同,且每個client上的標簽類別數量為兩個。
注意,這里算法保證的是每個client上標簽類別的近似數量為兩個,而不是保證每個client上標簽類別的絕對數量為兩個,因為該算法對兩個類別的話是直接將按標簽排序的樣本切分為n_client * 2
個塊,然后每個client分得2個塊。比如,如果我們不使用CIFAR10數據集,而是對EMNIST數據集(一共62個類別)進行划分,就會得到下面這樣的近似划分結果:
不過,該算法相比下面按照\(\alpha=1.0\)的Dirichlet分布划分的樣本(EMNIST數據集)仍然具有大大的不同。這證明我們的樣本划分算法是有效的。
參考
- [1] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.