聯邦學習中的模型聚合


論文[1]在聯邦學習的情景下引入了多任務學習,其采用的手段是使每個client/task節點的訓練數據分布不同,從而使各任務節點學習到不同的模型,且每個任務節點以及全局(global)的模型都由多個分量模型集成。該論文最關鍵與核心的地方在於將各任務節點學習到的模型進行聚合/通信,依據模型聚合方式的不同,可以將模型采用的算法分為client-server方法,和fully decentralized(完全去中心化)的方法(其實還有其他的聚合方法沒,如論文[3]提出的簇狀聚合方法,代碼參見[4]我們這里暫時略過),其中這兩種方法在具體實現上都可以替換為對代理損失函數的優化,不過我們這里暫時略過。

因為有多種任務聚合器(Aggregator)要實現,論文代碼(已開源在Github上,參見[2])采取的措施是先實現Aggregator抽象基類,實現好一些通用方法,並規定好抽象方法的接口,然后具體的任務聚合類繼承抽象基類,然后做具體的實現。

我們先來看任務聚合器(Aggregator)這一抽象基類

class Aggregator(ABC):
    r"""Aggregator的基類. `Aggregator`規定了client之間的通信"""
    def __init__(
            self,
            clients,
            global_learners_ensemble,
            log_freq,
            global_train_logger,
            global_test_logger,
            sampling_rate=1.,
            sample_with_replacement=False,
            test_clients=None,
            verbose=0,
            seed=None,
            *args,
            **kwargs
    ):

        rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
        self.rng = random.Random(rng_seed) # 隨機數生成器
        self.np_rng = np.random.default_rng(rng_seed) # numpy隨機數生成器

        if test_clients is None:
            test_clients = []

        self.clients = clients #  List[Client]
        self.test_clients = test_clients #  List[Client]

        self.global_learners_ensemble = global_learners_ensemble # List[Learner]
        self.device = self.global_learners_ensemble.device


        self.log_freq = log_freq
        self.verbose = verbose
        # verbose: 調整輸出打印的冗余度(verbosity), 
        # `0` 表示quiet(無任何打印輸出), `1` 顯示日志, `2` 顯示所有局部日志; 默認是 `0`
        self.global_train_logger = global_train_logger
        self.global_test_logger = global_test_logger

        self.model_dim = self.global_learners_ensemble.model_dim # #模型特征維度

        self.n_clients = len(clients)
        self.n_test_clients = len(test_clients)
        self.n_learners = len(self.global_learners_ensemble)

        # 存儲為每個client分配的權重(權重為0-1之間的小數)
        self.clients_weights =\
            torch.tensor(
                [client.n_train_samples for client in self.clients],
                dtype=torch.float32
            )
        self.clients_weights = self.clients_weights / self.clients_weights.sum()

        self.sampling_rate = sampling_rate  #  clients在每一輪使用的比例,默認為`1.`
        self.sample_with_replacement = sample_with_replacement #對client進行采用是可重復還是無重復的,with_replacement=True表示可重復的,否則是不可重復的

        # 每輪迭代需要使用到的client個數
        self.n_clients_per_round = max(1, int(self.sampling_rate * self.n_clients))

        # 采樣得到的client列表
        self.sampled_clients = list()

        # 記載當前的迭代通信輪數
        self.c_round = 0 
        self.write_logs()

    @abstractmethod
    def mix(self): 
        """
        該方法用於完成各client之間的權重參數與通信操作
        """
        pass

    @abstractmethod
    def update_clients(self): 
        """
        該方法用於將所有全局分量模型拷貝到各個client,相當於boardcast操作
        """
        pass

    def update_test_clients(self):
        """
        將全局(gobal)的所有分量模型都拷貝到各個client上
        """

    def write_logs(self):
        """
        對全局(global)的train和test數據集的loss和acc做記錄
        需要對所有client的所有樣本做累加,然后除以所有client的樣本總數做平均。
        """

    def save_state(self, dir_path):
        """
        保存aggregator的模型state,。例如, `global_learners_ensemble`中每個分量模型'learner'的state字典(以`.pt`文件格式),以及`self.clients` 中每個client的 `learners_weights` (注意,這個權重不是模型內部的參數,而是進行繼承的時候對各個分量模型賦予的權重,包含train和test兩部分,以一個大小為n_clients(n_test_clients)× n_learners的numpy數組的格式,即`.npy` 文件)。
        """

    def load_state(self, dir_path):
        """
        加載aggregator的模型state,即save_state方法里保存的那些
        """

    def sample_clients(self):
        """
        對clients進行采樣,
        如果self.sample_with_replacement為True,則為可重復采樣,
        否則,則為不可重復采用。
        最終得到一個clients子集列表並賦予self.sampled_clients
        """

1.client-server 算法

這種方式的通信/聚合方法也稱中心化(centralized)方法,因為該方法在每一輪迭代最后將所有client的權重數據匯集到server節點。這種方法的優化迭代部分的偽代碼示意如下:
CV多任務學習

落實到具體代碼實現上,這種方法的Aggregator設計如下:

class CentralizedAggregator(Aggregator):
    r""" 標准的中心化Aggreagator
    所有clients在每一輪迭代末和average client完全同步.
    """
    def mix(self):
        self.sample_clients()

        # 對self.sampled_clients中每個client的參數進行優化
        for client in self.sampled_clients:
            # 相當於偽代碼第11行調用的LocalSolver函數
            client.step()

        # 遍歷global模型(self.global_learners_ensemble) 中每一個分量模型(learner)
        # 相當於偽代碼第13行
        for learner_id, learner in enumerate(self.global_learners_ensemble):
            # 獲取所有client中對應learner_id的分量模型
            learners = [client.learners_ensemble[learner_id] for client in self.clients]
            # global模型的分量模型為所有client對應分量模型取平均,相當於偽代碼第14行
            average_learners(learners, learner, weights=self.clients_weights)

        # 將更新后的模型賦予所有clients,相當於偽代碼第5行的boardcast操作
        self.update_clients()

        # 通信輪數+1
        self.c_round += 1

        if self.c_round % self.log_freq == 0:
            self.write_logs()

    def update_clients(self):
        """
        此函數負責將所有全局分量模型拷貝到各個client,相當於偽代碼中第5行的boardcast操作
        """
        for client in self.clients:
            for learner_id, learner in enumerate(client.learners_ensemble):
                copy_model(learner.model, self.global_learners_ensemble[learner_id].model)

                if callable(getattr(learner.optimizer, "set_initial_params", None)):
                    learner.optimizer.set_initial_params(
                        self.global_learners_ensemble[learner_id].model.parameters()
                    )

2. fully decentralized(完全去中心化)算法

這種方法之所以被稱為去中心化的,因為該方法在每一輪迭代不需要所有client的權重數據匯集到一個特定的server節點,而只需要完成每個節點和其鄰居進行通信(參數共享)即可。這種方法的優化迭代部分的偽代碼示意如下:
CV多任務學習
落實到具體代碼實現上,這種方法的Aggregator設計如下:

class DecentralizedAggregator(Aggregator):
    def __init__(
            self,
            clients,
            global_learners_ensemble,
            mixing_matrix,
            log_freq,
            global_train_logger,
            global_test_logger,
            sampling_rate=1.,
            sample_with_replacement=True,
            test_clients=None,
            verbose=0,
            seed=None):

        super(DecentralizedAggregator, self).__init__(
            clients=clients,
            global_learners_ensemble=global_learners_ensemble,
            log_freq=log_freq,
            global_train_logger=global_train_logger,
            global_test_logger=global_test_logger,
            sampling_rate=sampling_rate,
            sample_with_replacement=sample_with_replacement,
            test_clients=test_clients,
            verbose=verbose,
            seed=seed
        )

        self.mixing_matrix = mixing_matrix
        assert self.sampling_rate >= 1, "partial sampling is not supported with DecentralizedAggregator"

    def update_clients(self):
        pass

    def mix(self):
        
        # 對各clients的模型參數進行優化
        for client in self.clients:
            client.step()

        # 存儲每個模型各參數混合的權重
        # 行對應不同的client,列對應單個模型中不同的參數
        # (注意:每個分量有獨立的mixing_matrix)
        mixing_matrix = torch.tensor(
            self.mixing_matrix.copy(),
            dtype=torch.float32,
            device=self.device
        )

        # 遍歷global模型(self.global_learners_ensemble) 中每一個分量模型(learner)
        # 相當於偽代碼第14行
        for learner_id, global_learner in enumerate(self.global_learners_ensemble):
            # 用於將指定learner_id的各client的模型state讀出暫存
            state_dicts = [client.learners_ensemble[learner_id].model.state_dict() for client in self.clients]

            # 遍歷global模型中的各參數, key對應模型中參數的名稱
            for key, param in global_learner.model.state_dict().items():
                shape_ = param.shape
                models_params = torch.zeros(self.n_clients, int(np.prod(shape_)), device=self.device)

                for ii, sd in enumerate(state_dicts):
                    # models_params的第ii個下標存儲的是第ii個client的(名為key的)參數
                    models_params[ii] = sd[key].view(1, -1) 

                # models_params的每一行是一個client的參數
                # @符號表示矩陣乘/矩陣向量乘
                # 故這里表示每個client參數是其他所有client參數的混合
                models_params = mixing_matrix @ models_params

                for ii, sd in enumerate(state_dicts):
                    # 將第ii個client的(名為key的)參數存入state_dicts中對應位置
                    sd[key] = models_params[ii].view(shape_)

            # 將更新好的參數從state_dicts存入各client節點的模型中
            for client_id, client in enumerate(self.clients):
                client.learners_ensemble[learner_id].model.load_state_dict(state_dicts[client_id])

        # 通信輪數+1
        self.c_round += 1

        if self.c_round % self.log_freq == 0:
            self.write_logs()

參考文獻


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM