[源碼解析] 深度學習分布式訓練框架 horovod (14) --- 彈性訓練發現節點 & State
0x00 摘要
Horovod 是Uber於2017年發布的一個易於使用的高性能的分布式訓練框架,在業界得到了廣泛應用。
本系列將通過源碼分析來帶領大家了解 Horovod。本文是系列第十四篇,看看horovod 如何動態發現節點 和 狀態信息。
本系列其他文章鏈接如下:
[源碼解析] 深度學習分布式訓練框架 Horovod (1) --- 基礎知識
[源碼解析] 深度學習分布式訓練框架 horovod (2) --- 從使用者角度切入
[源碼解析] 深度學習分布式訓練框架 horovod (3) --- Horovodrun背后做了什么
[源碼解析] 深度學習分布式訓練框架 horovod (4) --- 網絡基礎 & Driver
[源碼解析] 深度學習分布式訓練框架 horovod (5) --- 融合框架
[源碼解析] 深度學習分布式訓練框架 horovod (6) --- 后台線程架構
[源碼解析] 深度學習分布式訓練框架 horovod (7) --- DistributedOptimizer
[源碼解析] 深度學習分布式訓練框架 horovod (8) --- on spark
[源碼解析] 深度學習分布式訓練框架 horovod (9) --- 啟動 on spark
[源碼解析] 深度學習分布式訓練框架 horovod (10) --- run on spark
[源碼解析] 深度學習分布式訓練框架 horovod (11) --- on spark --- GLOO 方案
[源碼解析] 深度學習分布式訓練框架 horovod (12) --- 彈性訓練總體架構
[源碼解析] 深度學習分布式訓練框架 horovod (13) --- 彈性訓練之 Driver
0x01 設計點
本文對應架構圖中的 Host Discovery 部分,因為是被 Driver Main 調用,所以把兩部分一起展示出。
發現節點機制的幾個關鍵設計點如下:
- 有節點變化時候,如何即時發現?Horovod是通過定期調用完成。
- 發現節點變化時候,如何通知各個worker? Horovod通過構建了一個通知機制完成。即,每個worker把自己注冊到WorkerNotificationManager 之上,當有節點變化時候,WorkerNotificationManager 會逐一通知這些worker。
- worker得到通知之后,如何處理?Horovod 把worker的狀態在深度框架上進一步封裝成各種State,得到通知之后就會調用State的對應callback函數,或者同步狀態,或者進行其他處理。
0x02 發現機制
這部分代碼主要在:horovod/runner/elastic/discovery.py。
2.1 發現腳本
HostDiscoveryScript 的主要作用就是保存腳本(程序啟動時候設置進來),然后當執行 find_available_hosts_and_slots 的時候,調用這個發現腳本,得到 host 信息。
該腳本的輸出的格式 就是調用 horovodrun 時候 的 host 參數格式,比如:
$ sh ./discover_hosts.sh # 運行腳本,輸出節點信息
10.68.32.2:4
10.68.32.3:4
10.68.32.4:4
定義如下:
class HostDiscoveryScript(HostDiscovery):
def __init__(self, discovery_script, slots):
self._discovery_script = discovery_script # 設定腳本
self._default_slots = slots # 審定slots
super(HostDiscoveryScript, self).__init__()
def find_available_hosts_and_slots(self):
stdout = io.StringIO()
# 執行發現腳本
exit_code = safe_shell_exec.execute(self._discovery_script, stdout=stdout)
# 讀取腳本輸出,解析出來host信息
host_slots = {}
lines = set(stdout.getvalue().strip().split('\n'))
for line in lines:
host = line
if ':' in line:
host, slots = line.split(':')
host_slots[host] = int(slots)
else:
host_slots[host] = self._default_slots
return host_slots
2.2 HostManager
HostManager 是 host discovery 的核心,作用是維護當前 host 以及 狀態,其主要變量是:
- self._current_hosts : 當前的 host 信息,包括 slot,assign order 等等;
- self._hosts_state :當前的 host 狀態,包括黑名單,event 等;
- self._discovery :可以認為是對 發現腳本 的一個封裝,用來動態執行 發現腳本,獲取 host 信息;
class HostManager(object):
def __init__(self, discovery):
self._current_hosts = DiscoveredHosts(host_slots={}, host_assignment_order=[])
self._hosts_state = defaultdict(HostState)
self._discovery = discovery
def update_available_hosts(self):
# TODO(travis): also check for hosts removed from the blacklist in the future
# 檢查更新,給出是添加,還是刪除節點
def check_update(cur_host_slots, prev_host_slots):
res = HostUpdateResult.no_update
for prev_h in prev_host_slots:
if prev_h not in cur_host_slots:
# prev_h is a removed host
res |= HostUpdateResult.removed
for h in cur_host_slots:
if h not in prev_host_slots:
# h is an added host
res |= HostUpdateResult.added
elif cur_host_slots[h] > prev_host_slots[h]:
# h has more slots added
res |= HostUpdateResult.added
elif cur_host_slots[h] < prev_host_slots[h]:
# h has removed some slots
res |= HostUpdateResult.removed
return res
prev_host_slots = self._current_hosts.host_slots
prev_host_assignment_order = self._current_hosts.host_assignment_order
host_slots = self._discovery.find_available_hosts_and_slots()
if prev_host_slots != host_slots: # 有修改
# 找到不在黑名單里的host
available_hosts = set([host for host in host_slots.keys() if not self._hosts_state[host].is_blacklisted()])
# 找到host的order
host_assignment_order = HostManager.order_available_hosts(available_hosts, prev_host_assignment_order)
self._current_hosts = DiscoveredHosts(host_slots=host_slots,
host_assignment_order=host_assignment_order)
# 檢查更新
return check_update(self._current_hosts.host_slots, prev_host_slots)
else: # 沒修改就不更新
return HostUpdateResult.no_update
HostManager 核心邏輯是 update_available_hosts 方法,就是用來發現可用的 host。
2.2.1 order_available_hosts
order_available_hosts 的作用是:確保最老的host被賦予最低的rank,即rank 0,因為最老的host最有可能擁有原來訓練的模型以及訓練狀態,這些信息需要在下一輪新迭代之前,發給所有worker。
@staticmethod
def order_available_hosts(available_hosts, prev_host_assignment_order):
# We need to ensure this list preserves relative order to ensure the oldest hosts are assigned lower ranks.
host_assignment_order = [host for host in prev_host_assignment_order if host in available_hosts]
known_hosts = set(host_assignment_order)
for host in available_hosts:
if host not in known_hosts:
host_assignment_order.append(host)
return host_assignment_order
2.3 配置
我們看看是發現腳本如何配置進入HostManager之中。
首先,發現腳本是在_run_elastic之中配置。
def _run_elastic(args):
# construct host discovery component
if args.host_discovery_script:
# 如果參數中有設置發現腳本,則賦值為discover_hosts
discover_hosts = discovery.HostDiscoveryScript(args.host_discovery_script, args.slots)
elif args.hosts: # 如果參數設置好了hosts,則賦值為discover_hosts
_, available_host_slots = hosts.parse_hosts_and_slots(args.hosts)
if len(available_host_slots) < 2:
raise ValueError('Cannot run in fault tolerance mode with fewer than 2 hosts.')
discover_hosts = discovery.FixedHosts(available_host_slots)
else: # 拋出異常
raise ValueError('One of --host-discovery-script, --hosts, or --hostnames must be provided')
# 配置進入setting
settings = elastic_settings.ElasticSettings(discovery=discover_hosts,
.....)
env = os.environ.copy()
config_parser.set_env_from_args(env, args)
gloo_run_elastic(settings, env, args.command)
其次,發現腳本被設置到ElasticSettings之中。
class ElasticSettings(BaseSettings):
def __init__(self, discovery, min_np, max_np, elastic_timeout, reset_limit, **kwargs):
self.discovery = discovery
當啟動時候,會設置到 ElasticDriver 之中。
def start(self):
"""Starts the Horovod driver and services."""
self.rendezvous = RendezvousServer(self.settings.verbose)
self.driver = ElasticDriver(
rendezvous=self.rendezvous,
discovery=self.settings.discovery, # 在這里設置發現腳本
min_np=self.settings.min_np,
max_np=self.settings.max_np,
timeout=self.settings.elastic_timeout,
reset_limit=self.settings.reset_limit,
verbose=self.settings.verbose)
最后,建立HostManager時候,會設置發現腳本。
class ElasticDriver(object):
def __init__(self, rendezvous, discovery, min_np, max_np, timeout=None, reset_limit=None, verbose=0):
self._rendezvous = rendezvous
self._host_manager = HostManager(discovery) # 設置腳本
0x03 如何調用
3.1 無限循環線程
HostManager 的調用邏輯是在 ElasticDriver 類中。
ElasticDriver 在初始化時候,生成一個后台線程 _discovery_thread。
self._discovery_thread = threading.Thread(target=self._discover_hosts)
3.1.1 定時探尋
在 _discovery_thread
之中,會運行_discover_hosts。
ElasticDriver._discover_hosts
會:
- 首先調用
self._host_manager.update_available_hosts(self._host_manager.current_hosts, update_res)
得到最新的host狀態; - 其次,如果新 host 狀態已經發生的變化,於是就調用 _notify_workers_host_changes 和 _wait_hosts_cond.notify_all 來通知大家有 host 變化了;
def _discover_hosts(self):
first_update = True
while not self._shutdown.is_set():
self._wait_hosts_cond.acquire()
try:
# 得到最新的host狀態
update_res = self._host_manager.update_available_hosts()
if update_res != HostUpdateResult.no_update:
self._notify_workers_host_changes(self._host_manager.current_hosts, update_res)
self._wait_hosts_cond.notify_all() # 通知大家有 host 變化
except RuntimeError as e:
if first_update:
# Misconfiguration, fail the job immediately
self._shutdown.set()
self._wait_hosts_cond.notify_all() # 通知大家有 host 變化
raise
# Transient error, retry until timeout
logging.warning(str(e))
finally:
self._wait_hosts_cond.release()
first_update = False
self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS)
邏輯如下,是一個 thread loop 定時運行:
<--------------------^
+ |
| thread loop |
| |
| +----------------+----------------------+
| | ElasticDriver._discovery_thread |
| | |
| | |
| | |
| | |
| | HostManager.update_available_hosts |
| | |
| +----------------+----------------------+
| ^
| |
v |
+-------------------->+
3.1.2 通知變化
如果發現有host 變化,則調用 self._notify_workers_host_changes
來通知。
即,當Driver的定時進程通過節點發現腳本發現某一個節點被標記為新增或者移除時,它將 調用 _notify_workers_host_changes 發送一個通知到所有workers。
邏輯如下:
<--------------------^
+ |
| thread loop |
| |
| +----------------+-----------------------------------------------+
| | ElasticDriver._discovery_thread |
| | |
| | |
| | HostManager.update_available_hosts |
| | + |
| | | |
| | | |
| | v |
| | YES |
| | update_res != no_update ??? +--------+ |
| | + | |
| | | | |
| | | v |
| | | NO |
| | | _notify_workers_host_changes |
| | v |
| +----------------------------------------------------------------+
| |
| |
| |
v |
+-------------------->+
具體如下:
def _notify_workers_host_changes(self, current_hosts, update_res):
next_host_assignments = {}
if current_hosts.count_available_slots() >= self._min_np:
# Assignments are required to be stable via contract
next_host_assignments, _ = self._get_host_assignments(current_hosts)
if next_host_assignments == self.host_assignments:
# Skip notifying workers when host changes would not result in changes of host assignments
return
coordinator_slot_info = self.get_coordinator_info()
# 獲取 WorkerNotificationClient
coordinator_client = self.get_worker_client(coordinator_slot_info)
timestamp = _epoch_time_s()
coordinator_client.notify_hosts_updated(timestamp, update_res) # 通知
get_worker_client 函數就是獲取 WorkerNotificationClient,然后調用 WorkerNotificationClient 來進行通知,所以下面我們接下來看 WorkerNotificationClient。
def get_worker_client(self, slot_info):
return self._worker_clients.get((slot_info.hostname, slot_info.local_rank))
具體如下:
<--------------------^
+ |
| thread loop |
| |
| +----------------+------------------------------------+
| | ElasticDriver._discovery_thread |
| | + |
| | | |
| | v |
| | HostManager.update_available_hosts |
| | + |
| | | |
| | | |
| | v YES | +---------------------------+
| | update_res != no_update ??? +-----+ | | |
| | + | | | |
| | | | | | WorkerNotificationClient |
| | | v | notify_hosts_updated | |
| | | NO | | |
| | | _notify_workers_host_changes+------------------------> | |
| | v | | |
| +-----------------------------------------------------+ +---------------------------+
| |
| |
| |
v |
+-------------------->+
手機如下:
3.2 如何通知
就是利用 WorkerNotificationClient 發送 HostsUpdatedRequest。
3.2.1 WorkerNotificationClient
可以看到,WorkerNotificationService 繼承了 network.BasicService,所以 WorkerNotificationClient 就是作為 WorkerNotificationService 的操作接口,從而給 WorkerNotificationService 發送 HostsUpdatedRequest。
class WorkerNotificationClient(network.BasicClient):
def __init__(self, addresses, key, verbose, match_intf=False):
super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
addresses,
key,
verbose,
match_intf=match_intf)
def notify_hosts_updated(self, timestamp, update_res):
self._send(HostsUpdatedRequest(timestamp, update_res))
3.2.2 WorkerNotificationService
WorkerNotificationService 會響應 HostsUpdatedRequest。
class WorkerNotificationService(network.BasicService):
NAME = 'worker notification service'
def __init__(self, key, nic, manager):
super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
key,
nic)
self._manager = manager
def _handle(self, req, client_address):
if isinstance(req, HostsUpdatedRequest):
self._manager.handle_hosts_updated(req.timestamp, req.res)
return network.AckResponse()
return super(WorkerNotificationService, self)._handle(req, client_address)
3.2.3 WorkerNotificationManager
handle_hosts_updated 會逐一通知注冊在WorkerNotificationManager 上的 listener(就是用戶代碼中的 State)。
WorkerNotificationManager 是在 horovod/common/elastic.py 構建,每一個host上運行一個。
notification_manager = WorkerNotificationManager()
具體定義如下:
class WorkerNotificationManager(object):
def __init__(self):
self._lock = threading.Lock()
self._service = None
self._listeners = set()
def init(self, rendezvous_addr=None, rendezvous_port=None,
nic=None, hostname=None, local_rank=None):
with self._lock:
if self._service:
return
rendezvous_addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR)
if not rendezvous_addr:
return
rendezvous_port = rendezvous_port if rendezvous_port is not None else \
int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
local_rank = local_rank if local_rank is not None else \
int(os.environ.get(HOROVOD_LOCAL_RANK))
secret_key = secret.make_secret_key()
self._service = WorkerNotificationService(secret_key, nic, self)
value = (self._service.addresses(), secret_key)
put_data_into_kvstore(rendezvous_addr,
rendezvous_port,
PUT_WORKER_ADDRESSES,
self._create_id(hostname, local_rank),
value)
def register_listener(self, listener):
self._listeners.add(listener)
def remove_listener(self, listener):
self._listeners.remove(listener)
def handle_hosts_updated(self, timestamp, update_res):
for listener in self._listeners:
listener.on_hosts_updated(timestamp, update_res)
3.2.4 通知 State
我們再梳理以下流程:
- 當Driver的定時進程通過節點發現腳本發現某一個節點被標記為新增或者移除時,它將發送一個通知到所有workers。
- 每一個 worker 有自己對應的 State,都被存儲於
WorkerNotificationManager . _listeners
。 - _host_messages 會在state 之中注冊host的變化,就是往其 _host_messages 之中放入"host 有變化" 的消息。
- 因為這個消息不是一定要立即處理的,所以這里只是先放入 State 的隊列之中。
邏輯如下:
<--------------------^
+ |
| thread loop |
| |
| +----------------+------------------------------------+
| | ElasticDriver._discovery_thread |
| | + |
| | | |
| | v |
| | HostManager.update_available_hosts |
| | + |
| | | |
| | | |
| | v YES |
| | update_res != no_update ??? +-----+ | +--------------------------+ +----------------------------+
| | + | | | | | |
| | | | | | WorkerNotificationClient | | WorkerNotificationService |
| | | v | notify_hosts_updated | | HostsUpdatedRequest | |
| | | NO | | | | |
| | | _notify_workers_host_changes+------------------------> | | +-------------------> | |
| | v | | | | |
| +-----------------------------------------------------+ +--------------------------+ +----------------+-----------+
| | |
| | |
| | handle_hosts_updated |
v | |
+-------------------->+ v
+------------------+-----------+
| |
| WorkerNotificationManager |
+-----------+ +----------+ +----------+ | |
| | | | | | | |
| State 1 | | State 2 | ...... | State n | <---------------------+ _listeners |
| | | | | | | |
+-----------+ +----------+ +----------+ | |
| |
^ ^ ^ | |
| | | | |
on_hosts_updated | | on_hosts_updated | on_hosts_updated | |
| | | | |
+--------------+-------------------+-------------------------+ handle_hosts_updated |
| |
+------------------------------+
手機如下:
3.2.5 何時處理
何時處理這個通知?在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被調用時,state.check_host_updates 會從 _host_messages 中讀取消息,積累更新。
如 check_host_updates 方法中注釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時拋出 HostsUpdateInterrupt 異常,具體同步使用 _bcast_object(然后內部調用到了 MPI)。
我們接下來就會在 State 的介紹之中,講解check_host_updates 。
0x04 狀態抽象
Horovod 實現了一個 State 對象,這是把機器訓練的模型又做了一步抽象。
每一個Worker擁有一個 State 對象。
-
Horovod 把所有需要在workers之間同步的變量都放進 hvd.elastic.State (比如model parameters,optimizer state,當前epoch和batch進度等等)對象之中。
-
State 對象的作用是定期存儲訓練狀態,在需要時候從 State 對象中恢復機器學習的狀態。這樣在某些worker發生意外錯誤時,可以避免因為狀態被損壞而無法恢復現場。
-
比如,假設一個worker剛好在參數更新過程中突然掛掉,而此時部分梯度更新可能只更新到一半,這個狀態是不可逆而又無法繼續,導致參數是被損壞狀態無法用於恢復訓練。
4.1 State
State 的作用是:在不同的 worker 之中跟蹤內存狀態。
主要變量&方法是:
- on_reset : 當需要重啟狀態時候調用;
- on_hosts_updated :當有 host 變化時候調用,即 向 _host_messages 這個 queue 放入一個消息;
- commit :用戶會定期調用此函數,會存儲狀態(state)到內存,檢查 host 更改;
- 當有異常發生時,會拋出一個 HorovodInternalError 異常,當 hvd.elastic.run 捕獲到這個異常后,會利用最新一次commit中恢復所有狀態。
- 因為commit狀態代價高昂(比如如參數量太大會導致耗時過長),所以需要在"每個batch的處理時間"與"如果出錯,訓練需要從多久前的狀態恢復"之間選取一個平衡點。比如,如果你每訓練10個batches就commit一次,你就把復制時間降低了10倍。但是當發生錯誤時,你需要回滾到10個batches前的狀態。
- check_host_updates : 會從
_host_messages
中讀取消息,積累更新,如方法中注釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時拋出異常。具體同步使用_bcast_object
(然后內部調用到了 MPI);- 如果節點發現腳本可以預見到某個節點是需要被移除或新增,Elastic Horvod可以避免回滾操作。當Driver的定時進程通過節點發現腳本發現某一個節點被標記為新增或者移除時,它將發送一個通知到所有workers,於是在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被調用時,會拋出一個 HostsUpdateInterrupt 異常。這個異常類似於 HorovodInternalError 異常,但是參數狀態等不會從最近一次commit中恢復,而是從當前實時的參數中恢復。
- 一般來說,如果你的硬件設施是可靠與穩定的,並且你的編排系統會在任務節點移除時提供足夠的告警,你就可低頻次調用 state.commit() 函數,同時只在每個batch結束時調用相對不耗時的 state.check_host_updates() 來檢查節點變更情況。
- _reset_callbacks :用戶可以注冊一些回調函數到 hvd.elastic.State 對象中,用於響應worker成員發生變化的情況。
- 比如回調函數可以處理如下情況:
- 當worker數量發生改變時,學習率需要根據新的world size進行相應改變。
- 對數據集進行重新分區。
- 這些回調函數會在"Horovod被重啟之后"和"狀態在節點間同步之前"這兩個階段中間被調用。
- 比如回調函數可以處理如下情況:
具體定義如下:
class State(object):
"""State representation used for tracking in memory state across workers.
Args:
bcast_object: Function used to broadcast a variable from rank 0 to the other workers.
get_rank: Function that returns the current rank of this worker.
"""
def __init__(self, bcast_object, get_rank):
self._bcast_object = bcast_object
self._rank = get_rank
self._host_messages = queue.Queue()
self._last_updated_timestamp = 0
self._reset_callbacks = []
def on_reset(self):
self._host_messages = queue.Queue()
self.reset()
for callback in self._reset_callbacks:
callback()
def on_hosts_updated(self, timestamp, update_res):
self._host_messages.put((timestamp, update_res))
def commit(self):
self.save()
self.check_host_updates()
def check_host_updates(self):
"""Checks that a notification has been sent indicating that hosts can be added or will be removed.
Raises a `HostsUpdatedInterrupt` if such a notification has been received.
"""
# Iterate through the update messages sent from the server. If the update timestamp
# is greater than the last update timestamp, then trigger a HostsUpdatedException.
# 遍歷更新消息,如果更新時間戳大於上次更新時間戳,就觸發一個HostUpdateResult
last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
all_update = HostUpdateResult.no_update
while not self._host_messages.empty():
timestamp, update = self._host_messages.get()
if timestamp > last_updated_timestamp:
last_updated_timestamp = timestamp
all_update |= update
# In order to ensure all workers raise the exception at the same time, we need to sync
# the updated state across all the workers.
# TODO(travis): this should be a max allreduce to account for changes in rank 0
# 會從 `_host_messages` 中讀取消息,積累更新,如方法中注釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時拋出異常。具體同步使用 `_bcast_object`(然后內部調用到了 MPI)
prev_timestamp, self._last_updated_timestamp, all_update = \
self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))
# At this point, updated state is globally consistent across all ranks.
if self._last_updated_timestamp > prev_timestamp:
raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
因此,我們加入 Commit 之后,邏輯如圖:
<--------------------^
+ |
| thread loop |
| |
| +----------------+------------------------------------+
| | ElasticDriver._discovery_thread |
| | + |
| | | |
| | v |
| | HostManager.update_available_hosts |
| | + |
| | | |
| | | |
| | v YES |
| | update_res != no_update ??? +-----+ | +--------------------------+ +----------------------------+
| | + | | | | | |
| | | | | | WorkerNotificationClient | | WorkerNotificationService |
| | | v | notify_hosts_updated | | HostsUpdatedRequest | |
| | | NO | | | | |
| | | _notify_workers_host_changes+------------------------> | | +-------------------> | |
| | v | | | | |
| +-----------------------------------------------------+ +--------------------------+ +----------------+-----------+
| | |
| | |
| | _bcast_object handle_hosts_updated |
v | |
+-------------------->+ +-------------+----------------------+ v
| | | +------------------+-----------+
| | | | |
v v v | WorkerNotificationManager |
+--------------------+ +----+------+ +---+------+ +------+---+ | |
| | | | | | | | | |
| Python xxx.py +-------------------------------------> | State 1 | | State 2 | ...... | State n | <---------------------+ _listeners |
| | commit / check_host_updates | | | | | | | |
+--------------------+ +-----------+ +----------+ +----------+ | |
| |
^ ^ ^ | |
| | | | |
on_hosts_updated | | on_hosts_updated | on_hosts_updated | |
| | | | |
+--------------+-------------------+-------------------------+ handle_hosts_updated |
| |
+------------------------------+
具體如下:
我們接下來介紹各級派生類。
4.2 ObjectState
ObjectState 的目的是組裝成 simple Python objects。
class ObjectState(State):
"""State for simple Python objects.
Every object is specified as a keyword argument, and will be assigned as an attribute.
Args:
bcast_object: Horovod broadcast object function used to sync state dictionary.
get_rank: Horovod rank function used to identify is this process is the coordinator.
kwargs: Properties to sync, will be exposed as attributes of the object.
"""
def __init__(self, bcast_object, get_rank, **kwargs):
self._bcast_object = bcast_object
self._saved_state = kwargs
self._set_attrs()
super(ObjectState, self).__init__(bcast_object=bcast_object, get_rank=get_rank)
def save(self):
new_state = {}
for attr in self._saved_state.keys():
new_state[attr] = getattr(self, attr)
self._saved_state = new_state
def restore(self):
self._set_attrs()
def sync(self):
if self._saved_state:
self._saved_state = self._bcast_object(self._saved_state)
self._set_attrs()
def _set_attrs(self):
for attr, value in self._saved_state.items():
setattr(self, attr, value)
4.3 TensorFlowKerasState
Horovod 默認已提供標准的TensorFlow,Keras和PyTorch的狀態保持和恢復實現,如果需要在某些場景下自定義,可以重載 hvd.elastic.State 這個對象。
TensorFlowKerasState 是 TensorFlow Keras model and optimizer 的狀態抽象。
初始化函數中,會設置各種相關變量,比如廣播函數。
class TensorFlowKerasState(ObjectState):
def __init__(self, model, optimizer=None, backend=None, **kwargs):
self.model = model
if not _model_built(model):
raise ValueError('Model must be built first. Run `model.build(input_shape)`.')
self.optimizer = optimizer or model.optimizer
self.backend = backend
self._save_model()
if not backend or _executing_eagerly():
self._bcast_model = lambda: _broadcast_model(self.model, self.optimizer, backend=self.backend)
bcast_object = broadcast_object
else:
# For TensorFlow v1, we need to reuse the broadcast op to prevent incrementing the uids
bcast_op = broadcast_variables(_global_variables(), root_rank=0)
self._bcast_model = lambda: self.backend.get_session().run(bcast_op)
bcast_object = broadcast_object_fn(session=self.backend.get_session())
super(TensorFlowKerasState, self).__init__(bcast_object=bcast_object,
get_rank=rank,
**kwargs)
具體實現了幾個方法,基本就是 存儲,恢復 state,同步。
def save(self):
self._save_model()
super(TensorFlowKerasState, self).save()
def restore(self):
self._load_model()
super(TensorFlowKerasState, self).restore()
def sync(self):
self._bcast_model()
self._save_model()
super(TensorFlowKerasState, self).sync()
def _save_model(self):
if _executing_eagerly():
self._saved_model_state = [tf.identity(var) for var in self.model.variables]
self._saved_optimizer_state = [tf.identity(var) for var in self.optimizer.variables()]
else:
self._saved_model_state = self.model.get_weights()
self._saved_optimizer_state = self.optimizer.get_weights()
def _load_model(self):
if _executing_eagerly():
for var, saved_var in zip(self.model.variables, self._saved_model_state):
var.assign(saved_var)
for var, saved_var in zip(self.optimizer.variables(), self._saved_optimizer_state):
var.assign(saved_var)
else:
self.model.set_weights(self._saved_model_state)
self.optimizer.set_weights(self._saved_optimizer_state)
4.4 Restore
我們看到了,restore 會從內存中恢復模型。
def restore(self):
self._load_model()
super(TensorFlowKerasState, self).restore()
於是,我們有一個問題:何時調用restore?
發現是如果 horovod 捕獲了 HorovodInternalError 之后,會用 restore 來恢復。
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False
try:
while True:
if not skip_sync:
state.sync()
try:
return func(state, *args, **kwargs)
except HorovodInternalError:
state.restore() # 在這里調用
skip_sync = False
except HostsUpdatedInterrupt as e:
skip_sync = e.skip_sync
reset()
state.on_reset()
finally:
notification_manager.remove_listener(state)
return wrapper
0x05 總結
我們再次重復一下,發現節點機制的幾個關鍵設計點:
- 有節點變化時候,如何即時發現?Horovod是通過定期調用完成。
- 發現節點變化時候,如何通知各個worker? Horovod通過構建了一個通知機制完成。即,每個worker把自己注冊到WorkerNotificationManager 之上,當有節點變化時候,WorkerNotificationManager 會逐一通知這些worker。
- worker得到通知之后,如何處理?Horovod 把worker的狀態在深度框架上進一步封裝成各種State,得到通知之后就會調用State的對應callback函數,或者同步狀態,或者進行其他處理。
簡化版總體邏輯如下:
+-----------------------------v
^ thread loop |
| |
+----------------+----------------------+ |
| ElasticDriver._discovery_thread | |
_notify_workers_host_changes | | |
| | |
+------------------+ | |
| | | |
| | HostManager.update_available_hosts | |
| | | |
| +-----------------+---------------------+ |
| ^ |
| | |
| | |
| +----------<---------------+ v
v
+---------------------------+ HostsUpdatedReques +----------------------------+ handle_hosts_updated +----------------------------+
| | | | | |
| WorkerNotificationClient +----------------------> | WorkerNotificationService | +------------------> | WorkerNotificationManager |
| | | | | |
+---------------------------+ +----------------------------+ +--------+-------------------+
|
|
| on_hosts_updated
|
v
+----+---+
| State |
+--------+
手機如下:
至此,發現節點部分介紹完畢,因為本文只是使用了 WorkerNotificationService 完成通知,但是沒有深入介紹,所以下一篇介紹內部廣播和通知機制。
0xEE 個人信息
★★★★★★關於生活和技術的思考★★★★★★
微信公眾賬號:羅西的思考
如果您想及時得到個人撰寫文章的消息推送,或者想看看個人推薦的技術資料,敬請關注。