[源碼解析] 深度學習分布式訓練框架 horovod (21) --- 之如何恢復訓練
0x00 摘要
本文以 PyTorch on Horovod 為切入點,分析一下 Horovod 彈性訓練的恢復流程,具體涉及知識點有:
ElasticSampler與PyTorch 原生DistributedSampler 的區別,Horovod 彈性訓練如何恢復等。
本系列其他文章鏈接如下:
[源碼解析] 深度學習分布式訓練框架 Horovod (1) --- 基礎知識
[源碼解析] 深度學習分布式訓練框架 horovod (2) --- 從使用者角度切入
[源碼解析] 深度學習分布式訓練框架 horovod (3) --- Horovodrun背后做了什么
[源碼解析] 深度學習分布式訓練框架 horovod (4) --- 網絡基礎 & Driver
[源碼解析] 深度學習分布式訓練框架 horovod (5) --- 融合框架
[源碼解析] 深度學習分布式訓練框架 horovod (6) --- 后台線程架構
[源碼解析] 深度學習分布式訓練框架 horovod (7) --- DistributedOptimizer
[源碼解析] 深度學習分布式訓練框架 horovod (8) --- on spark
[源碼解析] 深度學習分布式訓練框架 horovod (9) --- 啟動 on spark
[源碼解析] 深度學習分布式訓練框架 horovod (10) --- run on spark
[源碼解析] 深度學習分布式訓練框架 horovod (11) --- on spark --- GLOO 方案
[源碼解析] 深度學習分布式訓練框架 horovod (12) --- 彈性訓練總體架構
[源碼解析] 深度學習分布式訓練框架 horovod (13) --- 彈性訓練之 Driver
[源碼解析] 深度學習分布式訓練框架 horovod (14) --- 彈性訓練發現節點 & State
[源碼解析] 深度學習分布式訓練框架 horovod (15) --- 廣播 & 通知
[源碼解析] 深度學習分布式訓練框架 horovod (16) --- 彈性訓練之Worker生命周期
[源碼解析] 深度學習分布式訓練框架 horovod (17) --- 彈性訓練之容錯
[源碼解析] 深度學習分布式訓練框架 horovod (18) --- kubeflow tf-operator
[源碼解析] 深度學習分布式訓練框架 horovod (17) --- 彈性訓練之容錯
[源碼解析] 深度學習分布式訓練框架 horovod (18) --- kubeflow tf-operator
[源碼解析] 深度學習分布式訓練框架 horovod (19) --- kubeflow MPI-operator
[源碼解析] 深度學習分布式訓練框架 horovod (20) --- Elastic Training Operator
0x01 總論
本文緣起於一個兄弟的留言:
請問在彈性訓練中,如果節點數目發生變化,數據怎么重新划分呢?比如一個epoch還沒有進行完,這時添加了新節點,新數據重新划分的話,當前內存中用舊數據訓練的模型還有效嗎?
我恰好在分析PyTorch分布式的時候也有類似疑問,所以就回頭再看看Horovod是如何實現的。
我們之前對於 Horovod 的分析和示例大多以 TensorFlow 為例。大家對各種框架如何在Horovod之中適配的總體邏輯和思路應該有了一個大致的認識,所以我們本部分主要看看一些PyTorch 相關的特殊之處。
使用PyTorch做切入的另外一個原因是:在恢復訓練這個流程上,PyTorch相關部分確實相對清晰明確。
在 horovod/torch/elastic/ 目錄下,有兩個文件 :state.py 和 sampler.py。既然是彈性相關,所以我們先來看看其特殊之處。
0x02 Sampler
在 horovod/torch/elastic/sampler.py 之中,有一個 ElasticSampler 類,我們看看具體針對彈性做了哪些處理。
因為 ElasticSampler 類之中注明,它的實現非常類似DistributedSampler,也就是 PyTorch 原生的實現,所以我們要先看看 DistributedSampler。
2.1 PyTorch Distributed Optimizer
2.1.1 定義
DistributedSampler代碼位於:torch/distributed/optim/optimizer.py。
總結一下DistributedSampler的分配方法是:每段連續的 num_replicas 個數據被拆成一個一個,分給 num_replicas 個進程,這樣就達到了不重疊不交叉的目的,但也要注意的是:這樣每個進程拿到的數據是不連續的。
__iter__ 代碼的一個技術細節是 本worker如何遍歷?
indices = indices[self.rank:self.total_size:self.num_replicas]
這里,num_replicas 實際就是rank的總數,起始位置是self.rank,結束位置是總數據長度,按照num_replicas(就是world size)作為步長來遞增,所以這里每個worker就會嚴格返回自己rank對應的那部分數據序號。
我們用一個例子來看看,比如:
a = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
print(a[0:15:3])
print(a[1:15:3])
print(a[2:15:3])
得到:
[1, 4, 7, 10, 13]
[2, 5, 8, 11, 14]
[3, 6, 9, 12, 15]
具體代碼如下:
class DistributedSampler(Sampler[T_co]):
def __iter__(self) -> Iterator[T_co]:
if self.shuffle: # 如果需要shuffle,則會基於epoch和seed進行處理
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else: # 否則直接返回數據集長度序列
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
# 是否需要補齊數據
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
# 依據自己的rank,依次返回自己的數據序號
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices) # 后續就使用這些indices來對數據進行提取
def __len__(self) -> int:
return self.num_samples
def set_epoch(self, epoch: int) -> None:
r"""
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
use a different random ordering for each epoch. Otherwise, the next iteration of this
sampler will yield the same ordering.
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
2.1.2 問題點
DistributedSampler 如果直接用到 彈性訓練,是有一定問題的,讓我們分析一下,有幾個問題:
- 如果用戶已經訓練了5輪,那么就意味着已經使用了前面5個批次的數據。假設此時加入了新的worker節點,那么就應該恢復訓練。那么對於已經使用過的前面 5 個批次的數據,按說就不應該再次被用來訓練了。
- 問題1: 恢復訓練之后,應該怎么去除已經處理的數據index?
- 如果加入或者減少節點,如果告訴 Sampler,我們需要更改提取規則,最起碼,num_replicas 需要被更新,以后按照新的 num_replicas 進行提取,比如原來5個節點,num_replicas = 5,現在6個節點,num_replicas 應該為 6。
- 問題2: 恢復訓練之后,何時調用
__iter__以進行新的訓練? - 問題3: 恢復訓練之后,何時修改 num_replicas?
- 問題2: 恢復訓練之后,何時調用
我們看看 DistributedSampler 就會發現,其__iter__之中,沒有任何保存狀態的相關信息。即如果重新開始訓練,依然會從全體數據中提取,而非從剩余數據中提取。也沒有發現對后面兩個問題的解決辦法。
因此,很難利用 DistributedSampler進行彈性訓練,所以 Horovod 就使用 ElasticSampler 來解決這個問題。
2.2 ElasticSampler
2.2.1 定義
從注釋中我們可以看到,ElasticSampler 自稱與 DistributedSampler 非常類似。我們隨后針對兩個類代碼比較可以看到,功能基本一致。
但是有兩個新加入的變量值得注意,即:
self.processed_indices = set()
self.remaining_indices = []
定義如下:
import math
import random
import torch.utils.data.distributed
from horovod.torch.mpi_ops import rank, size
class ElasticSampler(torch.utils.data.Sampler):
"""Sampler that partitions dataset across ranks and repartitions after reset events.
Works similar to `DistributedSampler`, but with an optional capability to record
which dataset indices have been processed each batch. When tracked by a `TorchState`
object, the sampler will automatically repartition the unprocessed indices among the
new set of workers.
In order to use this object successfully it is recommended that the user:
1. Include this object in the `TorchState`.
2. Call `record_batch` or `record_indices` after processing a set of samples.
3. Call `set_epoch` at the end of each epoch to clear the processed indices.
Args:
dataset: Dataset used for sampling (assumed to be of constant size).
shuffle: If `True` (default), shuffle the indices.
seed: Random seed used to shuffle the sampler when `shuffle=True`.
This number should be identical across all ranks (default: 0).
"""
def __init__(self, dataset, shuffle=True, seed=0):
self.dataset = dataset
self.shuffle = shuffle
self.seed = seed
self.epoch = 0
self.processed_indices = set() # 新加入的特色成員變量
self.num_replicas = 0
self.rank = 0
self.remaining_indices = [] # 新加入的特色成員變量
self.num_samples = 0
self.total_size = 0
self.reset()
2.2.2 彈性方案
具體彈性方案就圍繞之前提到的兩個變量來進行。
2.2.2.1 常規流程
我們回憶其注釋中提到的如何使用:
1. Include this object in the `TorchState`.
2. Call `record_batch` or `record_indices` after processing a set of samples.
3. Call `set_epoch` at the end of each epoch to clear the processed indices.
我們可以推導出來其內在邏輯:
- 進行本 epoch 訓練。
- 當使用
__iter__獲取下一批次數據時候,self.indices = self.remaining_indices[:]就會 只從未訓練的數據里面提取。 - 每處理一個批次數據 之后,用戶使用
record_batch或者record_indices來把已經訓練完的數據批次信息保存在processed_indices。這樣就記錄了已經訓練完的數據。 - 如果產生了問題,或者有節點變更,則:
- 會調用 reset 函數,reset 會把已經訓練完的數據
processed_indices從總數據中移除,剩下的self.remaining_indice就是沒有訓練的數據。 - 恢復訓練, 只從未訓練的數據里面提取。
- 會調用 reset 函數,reset 會把已經訓練完的數據
- 當使用
- 當完成這個epoch 之后,會調用
set_epoch來重置processed_indices,也會調用 reset 方法進行清零。
具體功能代碼是:
def set_epoch(self, epoch):
"""Sets the epoch for this sampler.
When `shuffle=True`, this ensures all replicas use a different random ordering
for each epoch.
Will clear and reset the `processed_indices` for the next epoch. It is important
that this is called at the end of the epoch (not the beginning) to ensure that
partially completed epochs do not reprocess samples.
Args:
epoch: Epoch number.
"""
self.epoch = epoch
# 這里也許有網友會有疑問,就是下面兩行代碼應該交換一下次序。
# 但是實際上是沒有問題的,因為 reset 其實在異常處理時候的作用更大,在這里其實就是個清零作用。
self.processed_indices = set()
self.reset()
def record_batch(self, batch_idx, batch_size):
"""Record indices at batch `batch_idx` with length `batch_size` as processed."""
indices = set(self.get_indices(batch_idx, batch_size))
self.record_indices(indices)
def record_indices(self, indices):
"""Record set `indices` as processed."""
self.processed_indices.update(indices) # 記錄已經訓練完的數據
def get_indices(self, batch_idx, batch_size):
"""Return list of indices at batch `batch_idx` with length `batch_size`."""
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(self.indices))
return self.indices[start_idx:end_idx]
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
self.processed_indices = state_dict['processed_indices'] # 從保存的數據中提取
self.reset()
def state_dict(self):
return dict( # 這里是為了State.save 時候調用,就是模型保存時候,需要保存這兩個變量
epoch=self.epoch,
processed_indices=self.processed_indices
)
def reset(self):
# size 代碼位於horovod/torch/mpi_ops.py,是 size = _basics.size,可以認為就是 hvd.size()
self.num_replicas = size() # 重新配置有幾個worker
self.rank = rank()
# Exclude any samples we have already processed this epoch
# 把已經訓練完的數據移除,得到的數據 remaining_indices 都是沒有經過訓練的
self.remaining_indices = [idx for idx in range(len(self.dataset))
if idx not in self.processed_indices]
self.num_samples = int(math.ceil(len(self.remaining_indices) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
self.indices = self.remaining_indices[:] # 從剩余數據中提取
if self.shuffle:
# Shuffle indices across workers deterministically in place
seed = self.seed + self.epoch
random.Random(seed).shuffle(self.indices)
# add extra samples to make it evenly divisible
self.indices += self.indices[:(self.total_size - len(self.indices))]
assert len(self.indices) == self.total_size
# subsample
# 本worker如何遍歷?起始index是self.rank,終止index是總數據長度,按照num_replicas來遞增
self.indices = self.indices[self.rank:self.total_size:self.num_replicas]
assert len(self.indices) == self.num_samples
# 后續就按照上面的遍歷邏輯來遍歷
return iter(self.indices)
def __len__(self):
return self.num_samples
2.2.2.2 異常處理
在 horovod/torch/elastic/state.py 之中,當重新訓練時候,會調用到 ElasticSampler 的 load_state_dict 方法。
而 load_state_dict 之中,會調用 reset,這樣就把已經訓練完的數據移除,得到的數據都是沒有經過訓練的。
所以重新訓練時候,本epoch之內,不會用已經訓練的數據再次重復訓練。
我們后續會詳細分析這個流程。
2.2.1 如何使用
ElasticSampler 的使用如下,代碼位於:examples/elastic/pytorch/pytorch_imagenet_resnet50_elastic.py。
本節我們主要介紹如何使用,就是正常使用/處理流程,后續會介紹異常處理,這里省略部分次要代碼。
2.2.1.1 主體代碼
主體代碼主要注意就是使用ElasticSampler分別配置了兩個彈性采樣器。
if __name__ == '__main__':
allreduce_batch_size = args.batch_size * args.batches_per_allreduce
# Elastic Horovod: use ElasticSampler to partition data among workers.
train_dataset = datasets.ImageFolder()
train_sampler = hvd.elastic.ElasticSampler(train_dataset) # 配置了彈性采樣
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=allreduce_batch_size,
sampler=train_sampler,
**kwargs)
val_dataset = datasets.ImageFolder()
val_sampler = hvd.elastic.ElasticSampler(val_dataset) # 配置了彈性采樣
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.val_batch_size,
sampler=val_sampler,
**kwargs)
# Set up standard ResNet-50 model.
model = models.resnet50()
# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(),
lr=(args.base_lr *
lr_scaler),
momentum=args.momentum, weight_decay=args.wd)
# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression,
backward_passes_per_step=args.batches_per_allreduce,
op=hvd.Adasum if args.use_adasum else hvd.Average,
gradient_predivide_factor=args.gradient_predivide_factor)
# Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast weights to other workers.
state = hvd.elastic.TorchState(model=model,
optimizer=optimizer,
train_sampler=train_sampler,
val_sampler=val_sampler,
epoch=resume_from_epoch,
batch=0)
full_train(state)
2.2.1.2 訓練代碼
以下代碼是具體訓練代碼。
def train(state):
model.train()
epoch = state.epoch
batch_offset = state.batch
with tqdm(total=len(train_loader),
desc='Train Epoch #{}'.format(epoch + 1),
disable=not verbose) as t:
# 循環獲取數據,會間接調用到 ElasticSampler 的 __iter__ 方法來獲取數據 index
for idx, (data, target) in enumerate(train_loader):
# Elastic Horovod: update the current batch index this epoch
# and commit / check for host updates. Do not check hosts when
# we commit as it would be redundant.
state.batch = batch_idx = batch_offset + idx
if args.batches_per_commit > 0 and \
state.batch % args.batches_per_commit == 0:
state.commit()
elif args.batches_per_host_check > 0 and \
state.batch % args.batches_per_host_check == 0:
state.check_host_updates()
adjust_learning_rate(epoch, batch_idx)
optimizer.zero_grad()
# Split data into sub-batches of size batch_size
for i in range(0, len(data), args.batch_size):
data_batch = data[i:i + args.batch_size]
target_batch = target[i:i + args.batch_size]
output = model(data_batch)
train_accuracy.update(accuracy(output, target_batch))
loss = F.cross_entropy(output, target_batch)
train_loss.update(loss)
# Average gradients among sub-batches
loss.div_(math.ceil(float(len(data)) / args.batch_size))
loss.backward()
# Elastic Horovod: record which samples were processed this batch
# so we do not reprocess them if a reset event occurs
# 這里會記錄已經完成的數據
state.train_sampler.record_batch(idx, allreduce_batch_size)
# Gradient is applied across all ranks
optimizer.step()
state.commit()
def end_epoch(state):
state.epoch += 1
state.batch = 0
state.train_sampler.set_epoch(state.epoch) # 這里會對剩余數據信息清零
state.commit()
@hvd.elastic.run
def full_train(state):
while state.epoch < args.epochs:
train(state)
validate(state.epoch)
save_checkpoint(state.epoch)
end_epoch(state) # 這里會對剩余數據信息清零
某一個epoch具體邏輯(正常處理)如下:
- 如果是最初運行,則調用reset進行初始化,其中會依據 dataset 長度構建一個 index list。用這個index list 減去 processed_indices ,就得到了本次epoch應該處理的數據 index,賦值給 remaining_indices,就是剩下來應該處理的數據index;
- 在
__iter__函數中,調用self.indices = self.remaining_indices[:],這樣 indices 就可以用來做迭代提取; - 訓練函數中,調用 iter(indices) 進行迭代提取,然后調用 record_indices 把本次使用過的index 更新到 processed_indices 之中。processed_indices 就記錄了目前使用的所有index;
- epoch 結束之后,調用 set_epoch 進行重置,即給 processed_indices 清零,調用 reset 重置 remaining_indices;
+---------------------------------------------------------------+
| ElasticSampler |
| |
+--------------------------------------------> + |
4 | set_epoch | | |
| | | |
| | 1 | reset |
| | | |
| | | |
| | v |
| | |
| | remaining_indices = dataset - processed_indices |
| | |
| | + |
| | | |
| | | |
| | 2 | __iter_ |
| | | |
| | | |
| | v |
| | indices = remaining_indices[:] |
| | + |
| | | |
| +---------------------------------------------------------------+
| |
| 3 |
| |
| v
| +--------------------------------------+------------------------------------+
| | train() train loop |
| | |
| | ----------------------------> iter(indices)+--------------------> |
| | ^ | |
| | | | |
| | step() backward() |
| | | +----------------------------------------+ | |
| | | |record_indices | | |
| | | | | | |
| | <-------------+ processed_indices.update(indices) +------+ v |
| | | | |
| | +----------------------------------------+ |
| | |
| +---------------------------------------+-----------------------------------+
| |
| |
+-----------------------------------------------+
0x03 保存和定期檢查
3.1 定期保存
Hovorod 建議用戶定周期性調用 state.commit() 來把狀態(state)備份到內存。
- 定期備份非常有用。在某些worker發生意外錯誤時,定期備份可以避免因為狀態被損壞而在重新訓練時候無法恢復現場。比如,如果一個worker剛好在更新參數過程中突然出錯,此時部分梯度更新完畢,部分梯度可能只更新到一半,這個狀態是不可逆轉而又無法繼續。因此,當此狀態發生時,會拋出一個 HorovodInternalError 異常,當 hvd.elastic.run 捕獲到這個異常后,會利用最新一次commit中恢復所有狀態。
- 因為commit狀態代價高昂(比如如參數量太大會導致耗時過長),所以需要在"每個batch的處理時間"與"如果出錯,訓練需要從多久前的狀態恢復"之間選取一個平衡點。比如,如果你每訓練10個batches就commit一次,你就把復制時間降低了10倍。但是當發生錯誤時,你需要回滾到10個batches前的狀態。
- Elastic Horowod可以通過執行我們稱之為“優雅地移除worker”操作來避免這些回滾。如果driver進程發現主機已可用或標記為刪除,它將向所有workers推送一個通知。於是在下次調用state.commit()或更輕量級的state.check_host_updates()時,一個HostsUpdatedInterrupt異常將被拋出。此異常的處理方式與“HorovodInternalError”類似,只是參數狀態不會還原到上次commit,而是從當前實時參數中恢復。
- 一般來說,如果你的硬件設施是可靠與穩定的,並且你的編排系統會在任務節點移除時提供足夠的告警,你就可低頻次調用 state.commit() 函數,同時只在每個batch結束時調用相對不耗時的 state.check_host_updates() 來檢查節點變更情況。
具體示例代碼如下:
@hvd.elastic.run
def train(state):
for state.epoch in range(state.epoch, epochs):
for state.batch in range(state.batch, batches_per_epoch):
data, target = get_random_batch()
train_one_batch(data, target)
if state.batch % batches_per_commit == 0:
state.commit() # 定期保存
state.batch = 0
3.2 異常處理
我們可以看到,HorovodInternalError 和 HostsUpdatedInterrupt 這兩個異常最大的區別:
- HorovodInternalError 異常:當 hvd.elastic.run 捕獲到這個異常后,會利用最新一次commit中恢復所有狀態。
- HostsUpdatedInterrupt 異常:處理方式與“HorovodInternalError”類似,只是參數狀態不會還原到上次commit,而是從當前實時參數中恢復。
之所以要強調這個,因為后面就要介紹如何做到不同恢復。
3.3 Commit
在用戶調用 State.commit 的時候,有兩個動作:一個是保存狀態。一個是調用 check_host_updates 檢查更新。
class State(object):
"""State representation used for tracking in memory state across workers."""
def commit(self):
self.save()
self.check_host_updates()
這里 save 就會調用到 State 的 save 操作,結合本文,就是下面要介紹的 TorchState 的 save 操作。
另外,check_host_updates 會拋出HostsUpdatedInterrupt異常。HostsUpdatedInterrupt 異常里面,是否需要 sync,從下面 check_host_updates 代碼可以看出來,就是如果節點數目有變化了,就需要sync。HostUpdateResult.removed 數值為1,這里其實可以改進,HostUpdateResult.removed 在目前這個情況之下,設定過細了。
class HostUpdateResult(IntFlag):
no_update = 0
removed = 1
added = 2
mixed = removed | added
def check_host_updates(self):
"""Checks that a notification has been sent indicating that hosts can be added or will be removed.
Raises a `HostsUpdatedInterrupt` if such a notification has been received.
"""
# Iterate through the update messages sent from the server. If the update timestamp
# is greater than the last update timestamp, then trigger a HostsUpdatedException.
last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
all_update = HostUpdateResult.no_update
while not self._host_messages.empty():
timestamp, update = self._host_messages.get()
if timestamp > last_updated_timestamp:
last_updated_timestamp = timestamp
all_update |= update
# In order to ensure all workers raise the exception at the same time, we need to sync
# the updated state across all the workers.
# TODO(travis): this should be a max allreduce to account for changes in rank 0
prev_timestamp, self._last_updated_timestamp, all_update = \
self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))
# At this point, updated state is globally consistent across all ranks.
if self._last_updated_timestamp > prev_timestamp:
# 在這里設定,其實含義就是:如果節點有變化,就設置為True,需要同步
raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed) # 拋出異常
0x04 State
我們接下來介紹異常處理邏輯,具體圍繞着 State 來介紹。對於State,我們先回憶一下其在恢復訓練時候的邏輯。
4.1 恢復訓練
重新訓練時候,會拋出兩種異常:
- 如果是 ring allreduce 相關,就轉為拋出異常 HorovodInternalError(e)。
- 如果當驅動進程通過節點發現腳本發現一個節點被標記為新增或者移除時,會拋出異常 HostsUpdatedInterrupt。
然后會進行如下處理:
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False
try:
while True:
if not skip_sync:
state.sync() # 進行同步
try:
return func(state, *args, **kwargs)
except HorovodInternalError:
state.restore() # 進行恢復訓練
skip_sync = False # 需要同步
except HostsUpdatedInterrupt as e:
skip_sync = e.skip_sync # 記錄是否需要同步
reset()
state.on_reset() # 進行重啟
finally:
notification_manager.remove_listener(state)
return wrapper
邏輯如下:
+------------------------------------------------------------------------------+
| Worker |
| |
| +------------------------------------------------------------------------+ |
| | run_fn | |
| | +----------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v | | |
| | | | |
| | state.sync() | | |
| | + | | |
| | | | | |
| | | | | |
| | v | | |
| | +------------------+---------------+ | | |
| | | train | | | |
| | | | | | |
| | | optimizer.apply_gradients +---------+ | | |
| | | | | | | |
| | +-------+ state.commit() | | | |
| | | | | | | | |
| | | +----------------------------------+ | | | |
| | | | | | |
| | v v | | |
| | HostsUpdatedInterrupt HorovodInternalError | | |
| | + | | |
| | + | | | |
| | | | | | |
| | | v | | |
| | | state.restore() | | |
| | | + | | |
| | | | | | |
| | +------------------+ <------------------+ | | |
| | | | | | |
| | | | | | |
| | v v | | |
| | reset() | | |
| | | | |
| | state.on_reset() | | |
| | | | |
| | + | | |
| | | | | |
| | +-----------------------------------> | |
| | | |
| +------------------------------------------------------------------------+ |
| |
+------------------------------------------------------------------------------+
因為這里涉及了大量的state操作,所以我們接下來要看看 TorchState:
4.2 TorchState
首先,我們要看看 TorchState 如何使用。當調用時候,使用如下方法來生成一個TorchState:
state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0)
state.register_reset_callbacks([on_state_reset]) # 注冊用戶定義的方法 on_state_reset
train(state)
其次,我們看看 TorchState 的定義,這里的 sync,restore,reset方法就在恢復訓練中被調用。
在初始化函數 __init__ 之中,會設置 handler,以我們的調用為例,就是 train_sampler,val_sampler這兩個對應的sampler會配置對應的handler,即SamplerStateHandler。
TorchState 繼承了 ObjectState,ObjectState 繼承了 State,所以前面提到的 commit 代碼中的 self.save(),就會調用到TorchState.save,而這里又會調用到 SamplerStateHandler.save。
class TorchState(ObjectState):
"""State representation of a PyTorch training process.
Multiple models and optimizers are supported by providing them as
kwargs. During initialization, `TorchState` will assign attributes
for every keyword argument, and handle its state synchronization.
Args:
model: Optional PyTorch model.
optimizer: Optional PyTorch optimizer.
kwargs: Attributes sync, will be exposed as attributes of the object. If a handler exists
for the attribute type, it will be used to sync the object, otherwise it will be
handled an ordinary Python object.
"""
def __init__(self, model=None, optimizer=None, **kwargs):
kwargs.update(dict(model=model, optimizer=optimizer))
# 這里會設置 handler,以我們的調用為例,就是train_sampler,val_sampler這兩個對應的sampler會配置對應的handler
self._handlers, kwargs = _get_handlers(kwargs)
for name, handler in self._handlers.items():
setattr(self, name, handler.value)
super(TorchState, self).__init__(bcast_object=broadcast_object,
get_rank=rank,
**kwargs)
def save(self):
for handler in self._handlers.values():
handler.save() # 調用到save,針對我們,就是調用到了SamplerStateHandler的save
super(TorchState, self).save()
def restore(self):
# 會進行恢復狀態
for handler in self._handlers.values():
handler.restore() # 這里會調用到sampler的restore方法。
super(TorchState, self).restore()
def sync(self):
# 會進行同步狀態
for handler in self._handlers.values():
handler.sync() # 這里會調用到sampler的sync方法。
super(TorchState, self).sync()
def __setattr__(self, name, value):
if hasattr(self, name) and name in self._handlers:
self._handlers[name].set_value(value)
super().__setattr__(name, value)
基類代碼中有:
class State(object):
def on_reset(self):
self._host_messages = queue.Queue()
self.reset() # 調用到reset
for callback in self._reset_callbacks:
callback()
4.3 設置 handler
上節中,我們可以看到,無論是reset,還是restore,都會調用到 _handlers 來進行處理,所以我們需要進一步分析。
首先就是如何設置handler。具體參見如下代碼,主要是通過一個全局配置 _handler_registry 來指定哪個 handler 處理哪種類型實例,比如這里有 (ElasticSampler, SamplerStateHandler),就代表着 SamplerStateHandler 是用來處理 ElasticSampler的 handler。
_handler_registry = [
(torch.nn.Module, ModelStateHandler),
(torch.optim.Optimizer, OptimizerStateHandler),
(ElasticSampler, SamplerStateHandler), # SamplerStateHandler 是用來處理 ElasticSampler的
]
def get_handler_registry():
return _handler_registry
def set_handler_registry(registry):
global _handler_registry
_handler_registry = registry
def _get_handler(v):
# 依據我們的樣例代碼,v是 train_sampler,而 train_sampler,val_sampler就是 ElasticSampler 的實例,所以得到 handler_type是 ElasticSampler,則會構建一個 SamplerStateHandler 並且返回
for handler_type, handler_cls in _handler_registry:
if isinstance(v, handler_type):
return handler_cls(v) # 調用 SamplerStateHandler(train_sampler) 生成實例
return None
def _get_handlers(kwargs):
handlers = {}
remainder = {}
# 這里k,v就是 train_sampler=train_sampler,所以 k 是 "train_sampler", v是實例 train_sampler
for k, v in kwargs.items():
handler = _get_handler(v)
if handler:
handlers[k] = handler
else:
remainder[k] = v
return handlers, remainder
4.4 SamplerStateHandler
既然知道了 ElasticSampler 由 SamplerStaeHandler 處理,就來分析一下 SamplerStateHandler。
初始化之后,self.value 就是 sampler,針對我們之前的分析,就是ElasticSampler。
SamplerStateHandler 具體代碼是,這里需要注意的是:初始化時候,會把ElasticSampler的狀態保存起來,以后如果出錯,會用此來恢復。
同時,save 也會被調用,用來恢復,我們馬上就會分析。
class SamplerStateHandler(StateHandler):
def __init__(self, sampler):
super().__init__(sampler)
# 這里會保存 ElasticSampler 的屬性和數據
self._saved_sampler_state = copy.deepcopy(self.value.state_dict())
def save(self):
# 保存 ElasticSampler 的屬性和數據
self._saved_sampler_state = copy.deepcopy(self.value.state_dict())
def restore(self):
# load_state_dict 會用__init__ 之中保存的原始數據來恢復,最終會調用到 ElasticSampler.reset 方法
self.value.load_state_dict(self._saved_sampler_state)
def sync(self):
# 1)Get the set of processed indices from all workers
world_processed_indices = _union(allgather_object(self.value.processed_indices))
# 2) Replace local processed indices with global indices
state_dict = self.value.state_dict() # 這里會調用到 ElasticSampler 的 state_dict 方法
state_dict['processed_indices'] = world_processed_indices
# 3) Broadcast and load the state to make sure we're all in sync
# 注意,這里的 load_state_dict 最終也會調用一次 reset
self.value.load_state_dict(broadcast_object(state_dict))
SamplerStateHandler 的 基類是:
class StateHandler(object):
def __init__(self, value):
self.value = value
def save(self):
raise NotImplementedError()
def restore(self):
raise NotImplementedError()
def sync(self):
raise NotImplementedError()
def set_value(self, value):
self.value = value
self.save()
4.5 保存
我們拓展一下save相關操作序列。
TorchState 繼承了 ObjectState,ObjectState 繼承了 State,所以:
- 前面提到的 commit 代碼中的 self.save(),就會調用到TorchState.save。
- 而TorchState.save又會調用到 SamplerStateHandler.save。
- SamplerStateHandler.save 會保存 ElasticSampler 的屬性和數據,就是保存了 ElasticSampler 的 epoch 和 processed_indices。
這樣,在定期 commit 的時候,就定期保存了模型的狀態和 ElasticSampler 的狀態,這些會在恢復訓練中用到。具體下圖所示:
+---------------------------+
| TorchState |
| |
| commit |
| + |
| | |
| | 1 |
| | |
| v |
| save |
| | |
| | |
+---------------------------+
|
| 2
|
|
+-----------------------------------------------------------------+
|SamplerStateHandler | |
| | |
| | |
| | |
| | |
| def save(self): v |
| |
| _saved_sampler_state = copy.deepcopy( value.state_dict() ) |
| + |
| | |
+-----------------------------------------------------------------+
|
|
| 3
|
|
+------------------------------------------+
| ElasticSampler | |
| | |
| | |
| | |
| def state_dict(self): | |
| return dict( v |
| self.epoch, |
| self.processed_indices |
| ) |
| |
+------------------------------------------+
只看靜態定義,還是很難理解,需要分析動態流程。因為有兩種異常,所以我們分開剖析。
回憶一下兩個異常最大的區別:
- HorovodInternalError 異常:當 hvd.elastic.run 捕獲到這個異常后,會利用最新一次commit中恢復所有狀態。
- HostsUpdatedInterrupt 異常:處理方式與“HorovodInternalError”類似,只是參數狀態不會還原到上次commit,而是從當前實時參數中恢復。
4.6 HostsUpdatedInterrupt
如果當驅動進程通過節點發現腳本發現一個節點被標記為新增或者移除時,會拋出異常 HostsUpdatedInterrupt。此時不是關鍵異常,因此可以繼續訓練本epoch,只是從后續訓練數據中,移除本epoch已經處理的數據。因此可以做到 參數狀態不會還原到上次commit,而是從當前實時參數中恢復。
下面代碼之中,我們只保留 HostsUpdatedInterrupt 相關代碼。
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False
try:
while True:
if not skip_sync:
state.sync() # 3) 進行同步
try:
return func(state, *args, **kwargs) # 這里會出錯,而且重新訓練也是來到這里
except HostsUpdatedInterrupt as e:
# 1) 進行異常處理
skip_sync = e.skip_sync # 2.1) 記錄是否需要同步
reset() # 2)這里會調用_basics.init 重新初始化 horovod,間接設定了ElasticSampler之中的 num_replicas
state.on_reset() # 進行重啟
finally:
notification_manager.remove_listener(state)
return wrapper
發生異常之后,
- 1)HostsUpdatedInterrupt 表示本 epoch 需要繼續訓練,所以進行異常處理,其中只是會:
- 1.1) 記錄本異常處理是否需要同步 :skip_sync = e.skip_sync。
- 2)這個步驟主要是重啟 hvd,對worker數目進行更改。具體是調用 State 自身的 reset() 方法(代碼位於
horovod/torch/elastic/__init__.py),其中會:- 2.1) 調用 shutdown() 來結束本次任務。
- 2.2) 調用 init(),從而調用_basics.init,最終重新建立 MPI 相關 context,所以 hvd.size() 就根據最新的worker數目進行了更改。后續
ElasticSampler.__iter__之中會相應修改num_replicas。
- 3)這個步驟是把已經訓練完的數據移除,得到的數據都是沒有經過訓練的。如果需要同步,則會調用 state.sync() ,其會調用 SamplerStateHandler.sync 方法,其內部會:
- 3.1) SamplerStateHandler會利用集合通信從所有worker中收集processed_indices,賦予給 world_processed_indices,這就是所有workers 已經處理過的數據 index。
- 3.2) 調用 ElasticSampler.state_dict方法,得到本地 ElasticSampler.epoch 和 ElasticSampler.processed_indices 的引用。然后將 world_processed_indices 賦值給 state_dict['processed_indices'],這樣,本地 ElasticSampler.processed_indices 就是所有workers 已經處理過的數據 index。
- 3.3)
self.value.load_state_dict(broadcast_object(state_dict))有兩步操作:- 廣播,這樣在同步之后,所有worker都有同樣的 state_dict['processed_indices'] 數據了。
- load_state_dict 會再調用一次 ElasticSampler.reset,此次 reset 會更改
num_replicas,也會從總數據中去除processed_indices,得到新的remaining_indices, 從而 后續__iter__之中,就會相應對提取index 的策略進行相應更改。
- 4)所以這樣就把已經訓練完的數據移除,所以得到的 remaining_indices 數據都是沒有經過訓練的。所以重新訓練時候,本epoch之內,不會用已經訓練的數據再次重復訓練,而是從當前實時參數中恢復。
- 重新訓練會調用 return func(state, *args, **kwargs) 進行訓練,這里會處理
ElasticSampler.__iter__。 - 當使用
__iter__獲取下一批次數據時候,self.indices = self.remaining_indices[:]就會 只從未訓練的數據里面提取。
- 重新訓練會調用 return func(state, *args, **kwargs) 進行訓練,這里會處理
具體邏輯如下:
+-----------------------------------------------------------------------------------------------------------------------+
| Worker |
| |
| +-----------------------------------------------------------------------------------------------------------------+ |
| | run_fn | |
| | +-----------------------------------------------------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v 3) | | |
| | state.sync() +------------------------------------------+----------------------+ | | |
| | | | | | |
| | + | | | | |
| | | | | | | |
| | | | | | | |
| | v | | | | |
| | +------------------+---------------+ 3.1) | 3.2) | | | |
| | | train | | | | | |
| | | | | | | | |
| | | optimizer.apply_gradients +---------+ | | | | |
| | | + | v | | | |
| | +-------+ state.commit() | | | | |
| | | | + | ElasticSampler.load_state_dict | | | |
| | | +----------------------------------+ | + | | | |
| | | | | | | | |
| | v v | | | | |
| | HostsUpdatedInterrupt HorovodInternalError v | | | |
| | + ElasticSampler.reset | | | |
| | + | + | | | |
| | | | | | | | |
| | | 1) v | | | | |
| | | state.restore() v | | | |
| | | + +-----------+-----------------+ | | | |
| | | | | ElasticSampler | | | | |
| | +------------------+ <------------------+ | | | | | |
| | | | | remaining_indices | | | | |
| | | | | | | | | |
| | v v | num_samples | | | | |
| | reset() | | | | | |
| | 2) | total_size | | | | |
| | state.on_reset() | | | | | |
| | | epoch | | | | |
| | + | | | | | |
| | | | processed_indices | | | | |
| | | | | | | | |
| | | | state_dict <-------------+ | | |
| | | | | | | |
| | | +-----------------------------+ | | |
| | | | | |
| | +------------------------------------------------------------------------------^ | |
| | | |
| +-----------------------------------------------------------------------------------------------------------------+ |
| |
+-----------------------------------------------------------------------------------------------------------------------+
手機如下:

4.7 HorovodInternalError
如果是 ring allreduce 相關,就轉為拋出異常 HorovodInternalError(e)。HorovodInternalError 是關鍵異常,此時本 epoch 現有狀態其實意義不大,應該利用最新一次commit中恢復所有狀態。
下面代碼之中,我們只保留 HorovodInternalError 相關代碼。
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False
try:
while True:
if not skip_sync:
state.sync() # 3) 進行同步
try:
return func(state, *args, **kwargs) # 這里會出錯,而且重新訓練也是來到這里
except HorovodInternalError:
# 1) 進行異常處理
state.restore() #1.1) 進行恢復訓練,這里是和 HostsUpdatedInterrupt 的不同之處
skip_sync = False # 1.2) 記錄需要同步
reset() # 2)這里會調用_basics.init 重新初始化 horovod,間接設定了ElasticSampler之中的 num_replicas
state.on_reset() # 進行重啟
finally:
notification_manager.remove_listener(state)
return wrapper
HorovodInternalError 和 HostsUpdatedInterrupt 的代碼路徑幾乎一樣,只是多了一步 state.restore() 。
這里為啥也要走查看節點變化這個代碼路徑呢?因為Horovod是定期檢查節點變化,所以可能產生HorovodInternalError時候,也有節點變化了,只是還沒有發現而已,所以可以一並處理了。
具體邏輯為:
- 1)HorovodInternalError 表示本 epoch 需要恢復訓練,所以先進行異常處理:
- 1.1)state.restore() 會調用 SamplerStateHandler.restore(這里是與HostsUpdatedInterrupt處理差異之處)。
- 進而調用 ElasticSampler.load_state_dict方法,會用在
SamplerStateHandler.__init__或者SamplerStateHandler.save之中原始保存的數據來恢復 ElasticSampler。保存的數據就是 processed_indices 和 epoch。 - ElasticSampler.load_state_dict方法 進而會調用 ElasticSampler.reset方法,使用 processed_indices 把已經訓練完的數據移除,最新得到的 remaining_indices 數據都是沒有經過訓練的(針對上次保存的 processed_indices 來說)。
- 進而調用 ElasticSampler.load_state_dict方法,會用在
- 1.2) 記錄本異常處理需要同步 : skip_sync = False。
- 1.1)state.restore() 會調用 SamplerStateHandler.restore(這里是與HostsUpdatedInterrupt處理差異之處)。
- 2)這個步驟主要是重啟 hvd。調用 State 自身的 reset() 方法(代碼位於
horovod/torch/elastic/__init__.py),其中會:- 2.1) 調用 shutdown() 來結束本次任務。
- 2.2) 調用 init(),從而調用_basics.init,最終重新建立 MPI 相關 context。
- 3)這個步驟是把已經訓練完的數據移除,得到的數據都是沒有經過訓練的。因為這里需要同步,所以會調用 state.sync() ,其會調用 SamplerStateHandler.sync 方法,其內部會:
- 3.1) SamplerStateHandler會利用集合通信從所有worker中收集processed_indices,賦予給 world_processed_indices,這就是所有workers 已經處理過的數據 index。需要注意的是:因為是使用在
__init__或者save之中原始保存的數據來恢復,所以其實這一步是恢復到上次commit狀態。 - 3.2) 調用 ElasticSampler.state_dict方法,得到本地 ElasticSampler.epoch 和 ElasticSampler.processed_indices 的引用。然后將 world_processed_indices 賦值給 state_dict['processed_indices'],這樣,本地 ElasticSampler.processed_indices 就是所有workers 已經處理過的數據 index。
- 3.3) 這里
self.value.load_state_dict(broadcast_object(state_dict))有兩步操作:- 廣播,這樣在同步之后,所有worker都有同樣的 state_dict['processed_indices'] 數據了。
- load_state_dict 會再調用一次 ElasticSampler.reset,此次 reset 會更改
num_replicas,也會從總數據中去除processed_indices,得到新的remaining_indices, 從而 后續__iter__之中,就會相應對提取index 的策略進行相應更改。
- 3.1) SamplerStateHandler會利用集合通信從所有worker中收集processed_indices,賦予給 world_processed_indices,這就是所有workers 已經處理過的數據 index。需要注意的是:因為是使用在
- 4)這樣就是恢復到epoch 上次 commit 的狀態進行訓練。
- 重新訓練會調用 return func(state, *args, **kwargs) 進行訓練,這里會處理
ElasticSampler.__iter__。 - 當使用
__iter__獲取下一批次數據時候,self.indices = self.remaining_indices[:]就會 只從未訓練的數據里面提取。
- 重新訓練會調用 return func(state, *args, **kwargs) 進行訓練,這里會處理
具體邏輯如下圖:
+--------------------------------------------------------------------------------------------------------------------+
| Worker |
| |
| +--------------------------------------------------------------------------------------------------------------+ |
| | run_fn | |
| | +-----------------------------------------------------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v 3 | | |
| | state.sync() +-----------------------------------------------------------------+ | | |
| | | | | |
| | + +--------------+ | | | |
| | | | | | | | |
| | | | | | | | |
| | v | v | | | |
| | +------------------+---------------+ | | | | |
| | | train | | SamplerStateHandler.restore | | | |
| | | | | + | | | |
| | | optimizer.apply_gradients +---------+ | | | | | |
| | | + | | | | | | |
| | +-------+ state.commit() | | v | | | |
| | | | + | | ElasticSampler.load_state_dict | | | |
| | | +----------------------------------+ | | + | | | |
| | | | | | | | | |
| | v v | | | | | |
| | HostsUpdatedInterrupt HorovodInternalError | v | | | |
| | + | ElasticSampler.reset | | | |
| | + | | + | | | |
| | | | | | | | | |
| | | v 1 | | | | | |
| | | state.restore()+-----+ v | | | |
| | | + +-----------+-----------------+ | | | |
| | | | | ElasticSampler | | | | |
| | +------------------+ <------------------+ | | | | | |
| | | | | remaining_indices | | | | |
| | | | | | | | | |
| | v v | num_samples | | | | |
| | reset() 2 | | | | | |
| | | total_size | | | | |
| | state.on_reset() | | | | | |
| | | epoch | | | | |
| | + | | | | | |
| | | | processed_indices | | | | |
| | | | | | | | |
| | | | state_dict <-------------+ | | |
| | | | | | | |
| | | +-----------------------------+ | | |
| | | | | |
| | +------------------------------------------------------------------------------^ | |
| | | |
| +--------------------------------------------------------------------------------------------------------------+ |
| |
+--------------------------------------------------------------------------------------------------------------------+
手機如下:

4.8 ElasticSampler.__iter__
到目前為止,我們還有一個問題沒有仔細分析,就是何時調用 ElasticSampler.__iter__
我們仔細梳理一下:
以下是彈性訓練總體邏輯:
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False
try:
while True:
if not skip_sync:
state.sync()
try:
# 如果出錯恢復,這里會繼續調用 func 進行訓練
return func(state, *args, **kwargs)
except HorovodInternalError:
state.restore()
skip_sync = False
except HostsUpdatedInterrupt as e:
skip_sync = e.skip_sync
reset()
state.on_reset()
finally:
notification_manager.remove_listener(state)
return wrapper
彈性邏輯使用注解來封裝了full_train,所以 func 就是 full_train。
@hvd.elastic.run
def full_train(state):
while state.epoch < args.epochs:
train(state)
validate(state.epoch)
save_checkpoint(state.epoch)
end_epoch(state)
我們看看 train 的主要代碼:
def train(state):
model.train()
epoch = state.epoch
with tqdm(...) as t:
# 這里 enumerate 之中會調用到 ElasticSampler.__iter__
for idx, (data, target) in enumerate(train_loader):
# Split data into sub-batches of size batch_size
for i in range(0, len(data), args.batch_size):
data_batch = data[i:i + args.batch_size]
target_batch = target[i:i + args.batch_size]
output = model(data_batch)
train_accuracy.update(accuracy(output, target_batch))
loss = F.cross_entropy(output, target_batch)
train_loss.update(loss)
# Average gradients among sub-batches
loss.div_(math.ceil(float(len(data)) / args.batch_size))
loss.backward()
# Elastic Horovod: record which samples were processed this batch
# so we do not reprocess them if a reset event occurs
state.train_sampler.record_batch(idx, allreduce_batch_size)
# Gradient is applied across all ranks
optimizer.step()
state.commit()
所以我們可以理出來總體邏輯:
當出錯恢復時候,train 會再次被調用,調用時候就會使用 enumerate(train_loader)調用到 ElasticSampler.__iter__。
num_replicas 在之前 reset 時候已經被設置,所以此時就是根據新的 world size 和 remaining_indices 重新確定提取數據的策略。
def __iter__(self):
self.indices = self.remaining_indices[:] # 從剩余數據中提取
if self.shuffle:
# Shuffle indices across workers deterministically in place
seed = self.seed + self.epoch
random.Random(seed).shuffle(self.indices)
# add extra samples to make it evenly divisible
self.indices += self.indices[:(self.total_size - len(self.indices))]
assert len(self.indices) == self.total_size
# subsample
# 本worker如何遍歷?起始index是self.rank,終止index是總數據長度,按照 num_replicas 來遞增
self.indices = self.indices[self.rank:self.total_size:self.num_replicas]
assert len(self.indices) == self.num_samples
# 后續就按照上面的遍歷邏輯來遍歷
return iter(self.indices)
具體邏輯如下,其中
1)在 reset 之中設置了num_replicas。
2)在 ElasticSampler.__iter__ 之中根據新的 world size 和 remaining_indices 重新確定提取數據的策略。
+----------------------------------------------------------------------------------------------------------------+
| Worker |
| |
| +----------------------------------------------------------------------------------------------------------+ |
| | run_fn | |
| | +----------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v | | |
| | | | |
| | state.sync() | | |
| | + | | |
| | | | | |
| | | | | |
| | v | | |
| | +--------------------------------+ +------------------+---------------+ | | |
| | | ElasticSampler | | train | | | |
| | | +---------------------------+ | | optimizer.apply_gradients +---------+ | | |
| | | | __iter__ | | 2) | | | | | |
| | | | | | <------------+ enumerate(train_loader) | | | | |
| | | | | | | | | | | |
| | | | remaining_indices | | +-------+ state.commit() | | | | |
| | | | | | | | | | | | |
| | | | | | | +----------------------------------+ | | | |
| | | | num_replicas | | v v | | |
| | | | | | HostsUpdatedInterrupt HorovodInternalError | | |
| | | | ^ | | + | | |
| | | | | | | + | | | |
| | | +---------------------------+ | | | | | |
| | +--------------------------------+ | v | | |
| | | | state.restore() | | |
| | | | + | | |
| | | | | | | |
| | | +------------------+ <------------------+ | | |
| | | | | | | |
| | | | | | | |
| | | 1) v v | | |
| | +----------------------------------------+ reset() | | |
| | | | |
| | state.on_reset() | | |
| | | | |
| | + | | |
| | | | | |
| | +-----------------------------------> | |
| | | |
| +----------------------------------------------------------------------------------------------------------+ |
| |
+----------------------------------------------------------------------------------------------------------------+
手機如下:

至此,彈性訓練如何恢復就分析完畢,以后可能結合 Pytorch 分布式 optimizer 來繼續分析。
0xFF 參考
pytorch中優化器optimizer.param_groups
