聯邦學習:按病態非獨立同分布划分Non-IID樣本


1 病態不獨立同分布(Non-IID)划分算法

在博客《分布式機器學習、聯邦學習、多智能體的區別和聯系》中我們提到論文[1]聯邦學習每個client具有數據不獨立同分布(Non-IID)的性質。

聯邦學習的論文多是用FEMNIST、CIFAR10、Shakespare、Synthetic等數據集對模型進行測試,這些數據集包括CV、NLP、普通分類/回歸這三種不同的任務。在單次實驗時,我們對原始數據集進行非獨立同分布(Non-IID) 的隨機采樣,為\(T\)個不同非任務生成\(T\)個不同分布的數據集。

我們在博客《聯邦學習:按Dirichlet分布划分Non-IID樣本》中已經介紹了按照Dirichlet分布划分non-IID樣本。
然而聯邦學習最開始采用的數據划分方法卻不是這種。這里我們重新回顧聯邦學習開山論文[1],它所采用的的是一種病態非獨立同分布(Pathological Non-IID)划分算法。以下我們以CIFAR10數據集的生成為例,來詳細地對該論文的數據集划分與采樣算法進行分析。

首先,如果選擇這種划分方式,需要指定則每個client上數據集所需要的標簽類型數做為超參, 該划分算法的函數原型一般如下:

def pathological_non_iid_split(dataset, n_classes, n_clients, n_classes_per_client):

我們解釋一下函數的參數,這里datasettorch.utils.Dataset類型的數據集,n_classes表示數據集里樣本分類數,n_client表示client節點的數量,該函數返回一個由n_client各自所需樣本索引組成的列表client_idcs

接下來我們看這個函數的內容。該函數完成的功能可以概括為:先將樣本按照標簽進行排序;再將樣本划分為n_client * n_classes_per_client個shards(每個shard大小相等),對n_clients中的每一個client分配n_classes_per_client個shards(分配到client后,每個client中的shards要合並)。

首先,從數據集索引data_idcs建立一個key為類別\(\{0,1,...,n\_classes-1\}\),value為對應樣本集索引列表的字典,這在實際上這就相當於按照label對樣本進行排序了

    data_idcs = list(range(len(dataset)))
    label2index = {k: [] for k in range(n_classes)}
    for idx in data_idcs:
        _, label = dataset[idx]
        label2index[label].append(idx)

    sorted_idcs = []
    for label in label2index:
        sorted_idcs += label2index[label]

然后該函數將數據分為n_clients * n_classes_per_client 個獨立同分布的shards,每個shards大小相等。然后給n_clients中的每一個client分配n_classes_per_client個shards(分配到client后,每個client中的shards要合並),代碼如下:

    def iid_divide(l, g):
        """
        將列表`l`分為`g`個獨立同分布的group(其實就是直接划分)
        每個group都有 `int(len(l)/g)` 或者 `int(len(l)/g)+1` 個元素
        返回由不同的groups組成的列表
        """
        num_elems = len(l)
        group_size = int(len(l) / g)
        num_big_groups = num_elems - g * group_size
        num_small_groups = g - num_big_groups
        glist = []
        for i in range(num_small_groups):
            glist.append(l[group_size * i: group_size * (i + 1)])
        bi = group_size * num_small_groups
        group_size += 1
        for i in range(num_big_groups):
            glist.append(l[bi + group_size * i:bi + group_size * (i + 1)])
        return glist


    n_shards = n_clients * n_classes_per_client
    # 一共分成n_shards個獨立同分布的shards
    shards = iid_divide(sorted_idcs, n_shards)
    np.random.shuffle(shards)
    # 然后再將n_shards拆分為n_client份
    tasks_shards = iid_divide(shards, n_clients)

    clients_idcs = [[] for _ in range(n_clients)]
    for client_id in range(n_clients):
        for shard in tasks_shards[client_id]:
            # 這里shard是一個shard的數據索引(一個列表)
            # += shard 實質上是在列表里並入列表
            clients_idcs[client_id] += shard 

最后,返回clients_idcs

    return clients_idcs

2 算法測試與可視化呈現

接下來我們在EMNIST數據集上調用該函數進行測試,並進行可視化呈現。我們設client數量\(N=10\),每個client規定有兩種標簽類型樣本。

import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import ConcatDataset


n_clients = 10
n_classes_per_client = 2  # 每個client規定有兩種標簽類型
seed = 42


if __name__ == "__main__":
    np.random.seed(seed)
    train_data = datasets.CIFAR10(
        root=".", download=True, train=True)
    test_data = datasets.CIFAR10(
        root=".", download=True, train=False)

    classes = train_data.classes
    n_classes = len(classes)

    labels = np.concatenate(
        [np.array(train_data.targets), np.array(test_data.targets)], axis=0)
    dataset = ConcatDataset([train_data, test_data])

    # 按照病態非獨立同分布對數據進行Non-IID划分
    client_idcs = pathological_non_iid_split(
        train_data, n_classes, n_clients, n_classes_per_client)

    # 展示不同client的不同label的數據分布
    plt.figure(figsize=(12, 8))
    label_distribution = [[] for _ in range(n_classes)]
    for c_id, idc in enumerate(client_idcs):
        for idx in idc:
            label_distribution[labels[idx]].append(c_id)

    plt.hist(label_distribution, stacked=True,
             bins=np.arange(-0.5, n_clients + 1.5, 1),
             label=classes, rwidth=0.5)
    plt.xticks(np.arange(n_clients), ["Client %d" %
                                      c_id for c_id in range(n_clients)])
    plt.xlabel("Client ID")
    plt.ylabel("Number of samples")
    plt.legend(loc="upper right")
    plt.title("Display Label Distribution on Different Clients")
    plt.show()

最終的可視化結果如下:

深度多任務學習實例1

可以看到,62個類別標簽在不同client上的分布確實不同,且每個client上的標簽類別數量為兩個。

注意,這里算法保證的是每個client上標簽類別的近似數量為兩個,而不是保證每個client上標簽類別的絕對數量為兩個,因為該算法對兩個類別的話是直接將按標簽排序的樣本切分為n_client * 2個塊,然后每個client分得2個塊。比如,如果我們不使用CIFAR10數據集,而是對EMNIST數據集(一共62個類別)進行划分,就會得到下面這樣的近似划分結果:

深度多任務學習實例1

不過,該算法相比下面按照\(\alpha=1.0\)的Dirichlet分布划分的樣本(EMNIST數據集)仍然具有大大的不同。這證明我們的樣本划分算法是有效的。

深度多任務學習實例1

參考

  • [1] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM