  1. """ A background process that averages your tensors with peers """
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import ctypes
  6. import multiprocessing as mp
  7. import os
  8. import random
  9. import threading
  10. import weakref
  11. from dataclasses import asdict
  12. from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
  13. import numpy as np
  14. import torch
  15. from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
  16. from hivemind.averaging.control import AveragingStage, StepControl
  17. from hivemind.averaging.group_info import GroupInfo
  18. from hivemind.averaging.load_balancing import load_balance_peers
  19. from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
  20. from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
  21. from hivemind.compression import (
  22. CompressionBase,
  23. CompressionInfo,
  24. NoCompression,
  25. deserialize_torch_tensor,
  26. serialize_torch_tensor,
  27. )
  28. from hivemind.dht import DHT, DHTID
  29. from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
  30. from hivemind.proto import averaging_pb2
  31. from hivemind.utils import MPFuture, TensorDescriptor, get_logger
  32. from hivemind.utils.asyncio import (
  33. achain,
  34. aiter_with_timeout,
  35. anext,
  36. as_aiter,
  37. azip,
  38. enter_asynchronously,
  39. switch_to_uvloop,
  40. )
  41. from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
  42. from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
  43. from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
  44. # flavour types
  45. GatheredData = Any
  46. logger = get_logger(__name__)


class DecentralizedAverager(mp.Process, ServicerBase):
    """
    Parameter averaging service. A trainer can run this service in the background to periodically average its
    parameters with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
    group of peers at a time, but (2) all trainers converge to the global average in a logarithmic number of steps.

    :param averaged_tensors: a sequence of pytorch tensors that will be averaged in each all-reduce
    :param dht: a DHT node that will be used to find groups
    :param start: if True, starts the background process immediately
    :param prefix: a shared prefix for all group keys
    :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
    :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
    :param min_matchmaking_time: when looking for a group, wait for requests for at least this many seconds
    :param compression: optionally compress tensors with this compression algorithm before running all-reduce
    :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
    :param allreduce_timeout: spend at most this many seconds on allreduce (after the group is formed)
    :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
      towards the (estimated) average by this coefficient. By default, local parameters are set equal to the average.
    :param request_timeout: when looking for a group, wait for a response from the leader for at most this many seconds.
    :note: request_timeout must be smaller than min_matchmaking_time to avoid potential deadlocks.
    :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
    :param bandwidth: if specified, this value represents the network bandwidth available to the averager.
      By default, the averager is assumed to have the average bandwidth of its group.
      If bandwidth == 0, the averager will rely on its groupmates to do all the averaging.
    :param client_mode: if False, this averager will accept incoming requests from other peers.
      If True, the averager will only join existing groups where at least one peer has client_mode=False.
      By default, this flag is copied from the DHTNode inside the ``dht`` instance.
    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
      local tensors for averaging
    :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
      with averager.allow_state_sharing = True / False
    :param declare_state_period: re-declare the averager as a donor for load_state_from_peers every this many seconds
    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating

    Example:

    >>> averager = DecentralizedAverager(...)
    >>> with averager.get_tensors() as tensors:
    >>>     # run some code, modify tensors if necessary
    >>>     tensors[0] += 1
    >>> # do not use tensors after the lock is released
    >>> metadata = averager.step(gather=dict(my_batch_size=32))
    >>> # run averaging once (in-place), gather metadata from groupmates
    >>> with averager.get_tensors() as tensors_after_averaging:
    >>>     pass  # use the averaged tensors
    """

    _matchmaking: Matchmaking
    _pending_group_assembled: asyncio.Event
    _state_updated: asyncio.Event
    _p2p: P2P
    serializer = MSGPackSerializer

    def __init__(
        self,
        averaged_tensors: Sequence[torch.Tensor],
        dht: DHT,
        *,
        start: bool,
        prefix: str,
        target_group_size: Optional[int] = None,
        min_group_size: int = 2,
        initial_group_bits: str = "",
        averaging_expiration: Optional[float] = None,
        min_matchmaking_time: float = 5.0,
        request_timeout: float = 3.0,
        averaging_alpha: float = 1.0,
        part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
        allreduce_timeout: Optional[float] = None,
        compression: CompressionBase = NoCompression(),
        state_compression: CompressionBase = NoCompression(),
        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
        bandwidth: Optional[float] = None,
        min_vector_size: int = 0,
        auxiliary: bool = False,
        allow_state_sharing: Optional[bool] = None,
        declare_state_period: float = 30,
        client_mode: Optional[bool] = None,
        daemon: bool = True,
        shutdown_timeout: float = 5,
    ):
  124. assert "." not in prefix, "group prefix must be a string without trailing '.'"
  125. assert bandwidth is None or (
  126. bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
  127. ), "bandwidth must be a non-negative float32"
  128. assert all(bit in "01" for bit in initial_group_bits)
  129. assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
  130. if averaging_expiration is not None:
  131. logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
  132. assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
  133. min_matchmaking_time = averaging_expiration
  134. super().__init__()
  135. self.dht = dht
  136. self.prefix = prefix
  137. if client_mode is None:
  138. client_mode = dht.client_mode
  139. self.client_mode = client_mode
  140. self._parent_pid = os.getpid()
  141. if self.client_mode:
  142. self.mode = AveragingMode.CLIENT
  143. elif auxiliary:
  144. self.mode = AveragingMode.AUX
  145. else:
  146. self.mode = AveragingMode.NODE
  147. self.daemon = daemon
  148. self._averaged_tensors = tuple(averaged_tensors)
  149. self.lock_averaged_tensors = mp.Lock()
  150. for tensor in self._averaged_tensors:
  151. assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
  152. tensor.share_memory_()
  153. self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
  154. self.schema_hash = compute_schema_hash(self._averaged_tensors)
  155. self.shutdown_timeout = shutdown_timeout
  156. self.bandwidth = bandwidth
  157. self.matchmaking_kwargs = dict(
  158. servicer_type=type(self),
  159. prefix=prefix,
  160. initial_group_bits=initial_group_bits,
  161. target_group_size=target_group_size,
  162. min_group_size=min_group_size,
  163. request_timeout=request_timeout,
  164. min_matchmaking_time=min_matchmaking_time,
  165. )
  166. self.allreduce_kwargs = dict(
  167. compression=compression,
  168. part_size_bytes=part_size_bytes,
  169. min_vector_size=min_vector_size,
  170. )
  171. self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
  172. self._running_groups: Dict[GroupID, AllReduceRunner] = {} # one or more assembled groups that run all-reduce
  173. self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True) # a control pipe used to communicate with daemon
  174. self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
  175. self._state_sharing_priority = mp.Value(ctypes.c_double, 0)
  176. if allow_state_sharing is None:
  177. allow_state_sharing = not client_mode and not auxiliary
  178. self.allow_state_sharing = allow_state_sharing
  179. self.declare_state_period = declare_state_period
  180. self.state_compression = state_compression
  181. self.tensor_infos = tensor_infos
  182. self._ready = MPFuture()
  183. # note: we create a background thread weakref and with daemon=True to ensure garbage collection
  184. background_fetcher = threading.Thread(
  185. daemon=True,
  186. target=_background_thread_fetch_current_state,
  187. args=[self.serializer, self._outer_pipe, weakref.WeakMethod(self.get_current_state)],
  188. )
  189. background_fetcher.start()
  190. if start:
  191. self.run_in_background(await_ready=True)

    @property
    def allow_state_sharing(self) -> bool:
        """if set to True, other peers can download this peer's state"""
        return bool(self._allow_state_sharing.value)

    @allow_state_sharing.setter
    def allow_state_sharing(self, value: bool):
        if value and self.client_mode:
            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
        else:
            old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
            if value != old_value:
                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))

    @property
    def state_sharing_priority(self) -> float:
        """Other peers will preferentially download state from the peers with the highest priority."""
        return float(self._state_sharing_priority.value)

    @state_sharing_priority.setter
    def state_sharing_priority(self, value: float):
        if value and self.client_mode:
            raise ValueError("State sharing priority is unused: averager in client mode cannot share its state.")
        else:
            old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
            if self.allow_state_sharing and value != old_value:
                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
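
    # Example (sketch): making this peer a preferred state donor for newcomers; `current_epoch`
    # is a hypothetical counter - any monotonically increasing value works (higher priority wins):
    #   averager.allow_state_sharing = True
    #   averager.state_sharing_priority = current_epoch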

    async def _trigger_declare_load_state(self):
        # note: previously tried to set mp.Event instead of this. Awaiting it in executor caused degradation in py39
        self._state_updated.set()

    @property
    def peer_id(self) -> PeerID:
        return self.dht.peer_id

    @property
    def request_timeout(self):
        return self._matchmaking.request_timeout

    def run(self):
        """
        Run the averager in a background thread; this is needed to avoid a heisenbug with broken OMP on fork.
        Turns out, using a non-main thread creates a separate OMP pool that works even if the original pool is corrupted.
        Read more: https://github.com/pytorch/pytorch/issues/17199
        """
        thread = threading.Thread(target=self._run_internal, daemon=True)
        thread.start()
        thread.join()

    def _run_internal(self):
        """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
        loop = switch_to_uvloop()
        # initialize asyncio synchronization primitives in this event loop
        pipe_semaphore = asyncio.Semaphore(value=0)
        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)

        async def _run():
            try:
                self._p2p = await self.dht.replicate_p2p()
                if not self.client_mode:
                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                else:
                    logger.debug("The averager is running in client mode")

                self._matchmaking = Matchmaking(
                    self._p2p,
                    self.schema_hash,
                    self.dht,
                    client_mode=self.client_mode,
                    **self.matchmaking_kwargs,
                )
                if not self.client_mode:
                    asyncio.create_task(self._declare_for_download_periodically())

                self._state_updated = asyncio.Event()
                self._pending_group_assembled = asyncio.Event()
                self._pending_group_assembled.set()
            except Exception as e:
                # Loglevel is DEBUG since normally the exception is propagated to the caller
                logger.debug(e, exc_info=True)
                self._ready.set_exception(e)
                return
            self._ready.set_result(None)

            while True:
                try:
                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self.request_timeout)
                except asyncio.TimeoutError:
                    pass
                if not self._inner_pipe.poll():
                    continue
                try:
                    method, args, kwargs = self._inner_pipe.recv()
                except (OSError, ConnectionError, RuntimeError) as e:
                    logger.exception(e)
                    await asyncio.sleep(self.request_timeout)
                    continue
                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                if method == "_shutdown":
                    await task
                    break

        loop.run_until_complete(_run())

    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
        """
        Starts the averager in a background process. If await_ready, this method will wait until the background
        averager is ready to process incoming requests, or for :timeout: seconds at most.
        """
        self.start()
        if await_ready:
            self.wait_until_ready(timeout)

    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
        self._ready.result(timeout=timeout)

    def shutdown(self) -> None:
        """Shut down the averager process"""
        if self.is_alive():
            self._outer_pipe.send(("_shutdown", [self.shutdown_timeout], {}))  # shut down the daemon process
            self._inner_pipe.send(("_SHUTDOWN", None))  # shut down the background thread in the host process
            self.join(self.shutdown_timeout)
            if self.is_alive():
                logger.warning("Averager did not shut down within the grace period; terminating it the hard way")
                self.terminate()
        else:
            logger.exception("Averager shutdown has no effect: the process is already not alive")

    async def _shutdown(self, timeout: Optional[float]) -> None:
        remaining_tasks = set()
        for group in self._running_groups.values():
            remaining_tasks.update(group.finalize(cancel=True))
        await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout)

    def __del__(self):
        if self._parent_pid == os.getpid() and self.is_alive():
            self.shutdown()

    def step(
        self,
        gather: Optional[GatheredData] = None,
        scheduled_time: Optional[DHTExpiration] = None,
        weight: Optional[float] = None,
        timeout: Optional[float] = None,
        allow_retries: bool = True,
        require_trigger: bool = False,
        wait: bool = True,
    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
        """
        Set up the averager to look for a group and run one round of averaging.

        :param gather: optionally send this information to all peers in the next group and gather it from every
          groupmate (this operation is known as all-gather). The gathered data will be available as the output of
          this function.
        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment.
          By default, all-reduce is scheduled for the current time plus min_matchmaking_time seconds
        :param weight: averaging weight for this peer, int or float, must be strictly positive
        :param allow_retries: if the averager fails to run one round of allreduce, this option will allow it to
          try again within the specified timeout
        :param require_trigger: if True, wait for the user to call .allow_allreduce() before running all-reduce
        :param timeout: if the averager was unable to *find* a group in this many seconds, consider allreduce failed
        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
        :returns: on success, update averaged_tensors and return group info; on failure, return None
        """
        if self.mode == AveragingMode.AUX and weight is not None:
            logger.warning("Averager is running in auxiliary mode, weight is unused")
        if scheduled_time is None:
            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
        if weight is None:
            weight = float(self.mode != AveragingMode.AUX)
        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
        assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
        assert not (wait and require_trigger), "Non-asynchronous step cannot wait for trigger (use wait=False)"
        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"

        user_data_for_gather = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
        data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, user_data_for_gather])
        step = StepControl(
            scheduled_time=scheduled_time,
            deadline=deadline,
            allow_retries=allow_retries,
            weight=weight,
            data_for_gather=data_for_gather,
        )

        future_for_init = MPFuture()
        self._outer_pipe.send(("_step", [], dict(step=step, future_for_init=future_for_init)))
        step.attach(*future_for_init.result())

        if not require_trigger:
            step.allow_allreduce()
        return step.result() if wait else step
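
    # Example (sketch): non-blocking step with an explicit trigger; assumes a running averager:
    #   control = averager.step(wait=False, require_trigger=True)
    #   ...                        # overlap matchmaking with other work, e.g. a backward pass
    #   control.allow_allreduce()  # we are ready: let all-reduce begin
    #   gathered = control.result(timeout=60)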

    async def _step(self, *, step: StepControl, future_for_init: MPFuture):
        try:
            trigger, cancel = MPFuture(), MPFuture()
            step.attach(trigger, cancel)
            future_for_init.set_result((trigger, cancel))

            while not step.done():
                try:
                    self._pending_group_assembled.clear()
                    step.stage = AveragingStage.LOOKING_FOR_GROUP
                    matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
                    check_cancel_task = asyncio.create_task(step.wait_for_cancel())

                    await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
                    if step.cancelled():
                        matchmaking_task.cancel()
                        raise asyncio.CancelledError()
                    else:
                        check_cancel_task.cancel()

                    group_info = await matchmaking_task
                    if group_info is None:
                        raise AllreduceException("Averaging step failed: could not find a group.")

                    if not step.triggered:
                        step.stage = AveragingStage.AWAITING_TRIGGER
                        await step.wait_for_trigger()

                    step.stage = AveragingStage.RUNNING_ALLREDUCE

                    step.set_result(
                        await asyncio.wait_for(
                            self._run_allreduce(
                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
                            ),
                            timeout=self._allreduce_timeout,
                        )
                    )
                    # averaging is finished, loop will now exit
                except (
                    AllreduceException,
                    MatchmakingException,
                    AssertionError,
                    StopAsyncIteration,
                    asyncio.CancelledError,
                    asyncio.InvalidStateError,
                    P2PHandlerError,
                ) as e:
                    if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
                        if not step.cancelled():
                            logger.exception(e)
                        if not step.done():
                            step.set_exception(e)
                    else:
                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")

        except BaseException as e:
            if not step.done():
                step.set_exception(e)
            raise
        finally:
            step.stage = AveragingStage.FINISHED
            if not step.done():
                step.set_exception(
                    RuntimeError(
                        "Internal sanity check failed: averager.step left future pending."
                        " Please report this to hivemind issues."
                    )
                )

    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
        try:
            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
            modes = tuple(map(AveragingMode, mode_ids))

            # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
            download_bandwidths = [
                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
            ]
            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
            )

            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                allreduce = AllReduceRunner(
                    p2p=self._p2p,
                    servicer_type=type(self),
                    prefix=self.prefix,
                    group_id=group_info.group_id,
                    tensors=local_tensors,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
                    gathered=user_gathered,
                    modes=modes,
                    **kwargs,
                )

                with self.register_allreduce_group(group_info.group_id, allreduce):
                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                            # all-reduce is performed asynchronously while iterating
                            tensor.add_(update, alpha=self._averaging_alpha)
                            self._state_updated.set()
                    else:
                        async for _ in allreduce:  # trigger all-reduce by iterating
                            raise ValueError("aux peers should not receive averaged tensors")

                return allreduce.gathered
        except BaseException as e:
            logger.exception(e)
            raise MatchmakingException(f"Unable to run All-Reduce: {e}")

    @contextlib.contextmanager
    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
        """registers a given all-reduce runner to listen for incoming connections"""
        try:
            self._running_groups[group_id] = allreduce
            self._pending_group_assembled.set()
            yield
        finally:
            self._running_groups.pop(group_id, None)
            self._pending_group_assembled.set()

    @contextlib.contextmanager
    def get_tensors(self) -> Sequence[torch.Tensor]:
        """
        A contextmanager that gives the user access to averaged tensors.
        It is guaranteed that the averager will not modify tensors while this context is active.
        Please do not modify the yielded tensors in-place after the context is released.
        """
        with self.lock_averaged_tensors:
            yield self._averaged_tensors

    async def rpc_join_group(
        self, request: averaging_pb2.JoinRequest, context: P2PContext
    ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
        """accept or reject a join request from another averager; if accepted, run it through the allreduce steps"""
        async for response in self._matchmaking.rpc_join_group(request, context):
            yield response

    async def rpc_aggregate_part(
        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
    ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """a groupmate sends us a part of its tensor; we should average it with other peers and return the result"""
        request = await anext(stream)
        if request.group_id not in self._running_groups:
            # this handles a special case when the leader accepted us into the group AND began allreduce right away,
            # but its response with group_id was delayed and other peers got to us first
            await self._pending_group_assembled.wait()

        group = self._running_groups.get(request.group_id)
        if group is None:
            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
            return

        async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
            yield message

    async def _declare_for_download_periodically(self):
        download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
        sharing_was_allowed = self.allow_state_sharing
        while True:
            expiration_time = get_dht_time() + self.declare_state_period
            if self.allow_state_sharing or sharing_was_allowed:
                # notify either if sharing is allowed or if it was just switched off (to overwrite the previous message)
                asyncio.create_task(
                    asyncio.wait_for(
                        self.dht.store(
                            download_key,
                            subkey=self.peer_id.to_bytes(),
                            value=self.state_sharing_priority if self.allow_state_sharing else None,
                            expiration_time=expiration_time,
                            return_future=True,
                        ),
                        timeout=expiration_time - get_dht_time(),
                    )
                )
                sharing_was_allowed = self.allow_state_sharing

            # report again either in declare_state_period or as soon as the field is changed by the user
            self._state_updated.clear()
            try:
                await asyncio.wait_for(self._state_updated.wait(), timeout=max(0.0, expiration_time - get_dht_time()))
            except asyncio.TimeoutError:
                pass

    async def rpc_download_state(
        self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
    ) -> AsyncIterator[averaging_pb2.DownloadData]:
        """
        Get the up-to-date trainer state from a peer.
        The state consists of two parts: (serialized_metadata, tensors)

        - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
        - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
        """
        if not self.allow_state_sharing:
            return  # deny request and direct the peer to the next prospective averager
        metadata, tensors, infos = await self._get_current_state_from_host_process()
        if infos is None:
            infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
        assert len(tensors) == len(infos)

        for tensor, info in zip(tensors, infos):
            for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
                if metadata is not None:
                    yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                    metadata = None
                else:
                    yield averaging_pb2.DownloadData(tensor_part=part)

    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
        """
        Get the current state and send it to a peer. Executed in the host process. Meant to be overridden.

        :returns: a tuple of (small metadata, sequence of torch tensors, sequence of CompressionInfo)
        :note: metadata must be serializable with self.serializer (default = MSGPackSerializer)
        """
        with self.get_tensors() as tensors:
            return dict(group_key=self.get_group_bits()), tensors, self.tensor_infos
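
    # Example (sketch): a subclass may override get_current_state to share extra state;
    # `local_epoch` below is a hypothetical attribute, not part of this class:
    #   class MyAverager(DecentralizedAverager):
    #       def get_current_state(self):
    #           with self.get_tensors() as tensors:
    #               return dict(epoch=self.local_epoch), list(tensors), self.tensor_infos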

    async def _get_current_state_from_host_process(self):
        """Executed in the averager process inside rpc_download_state"""
        future = MPFuture()
        self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
        return await future

    def load_state_from_peers(
        self, wait: bool = True, timeout: Optional[float] = None
    ) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
        """
        Try to download the latest optimizer state from one of the existing peers.

        :returns: on success, return a 2-tuple with (metadata, tensors), where

        - metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
        - tensors is a sequence of pytorch tensors meant to contain the peer's model weights and optimizer statistics

        The exact contents of both metadata and tensors are determined by the get_current_state method
        """
        future = MPFuture()
        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
        return future.result(timeout=timeout) if wait else future
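
    # Example (sketch): initializing a late-joining trainer from its peers; assumes at least
    # one peer has allow_state_sharing=True:
    #   loaded = averager.load_state_from_peers(timeout=30)
    #   if loaded is not None:
    #       metadata, tensors = loaded
    #       ...  # copy tensors into the local model / optimizer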

    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
        try:
            key_manager = self._matchmaking.group_key_manager
            peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
            peer_priority = {
                PeerID(peer_id): (float(info.value), random.random())  # using randomness as a tie breaker
                for peer_id, info in peer_priority.items()
                if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
            }

            if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
                future.set_result(None)
                return

            metadata = None
            for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
                if peer != self.peer_id:
                    logger.info(f"Downloading parameters from peer {peer}")
                    try:
                        stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                        stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                        current_tensor_parts, tensors = [], []

                        async for message in aiter_with_timeout(stream, timeout=timeout or self.request_timeout):
                            if message.metadata:
                                metadata = self.serializer.loads(message.metadata)
                            if message.tensor_part.dtype and current_tensor_parts:
                                # tensor_part.dtype indicates the start of a new tensor, so we should wrap up this one
                                tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
                                current_tensor_parts = []
                            current_tensor_parts.append(message.tensor_part)
                        if current_tensor_parts:
                            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))

                        if not metadata:
                            logger.debug(f"Peer {peer} did not send its state.")
                            continue

                        logger.info(f"Finished downloading state from {peer}")
                        future.set_result((metadata, tensors))
                        return
                    except Exception as e:
                        logger.exception(f"Failed to download state from {peer} - {repr(e)}")

        finally:
            if not future.done():
                future.set_result(None)

    def get_group_bits(self, wait: bool = True):
        """
        :param wait: if True, return bits immediately. Otherwise return an awaitable MPFuture
        :returns: the averager's current group key bits (without prefix)
        """
        future = MPFuture()
        self._outer_pipe.send(("_get_group_bits", [], dict(future=future)))
        return future.result() if wait else future

    async def _get_group_bits(self, future: MPFuture):
        future.set_result(self._matchmaking.group_key_manager.group_bits)

    def set_group_bits(self, group_bits: str, wait: bool = True):
        """
        :param group_bits: group bits (string of '0' or '1') to be used in the averager's group key
        :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
        """
        future = MPFuture()
        assert all(bit in "01" for bit in group_bits)
        self._outer_pipe.send(("_set_group_bits", [], dict(group_bits=group_bits, future=future)))
        return future.result() if wait else future
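
    # Example (sketch): pinning this averager to a specific matchmaking bucket:
    #   averager.set_group_bits("0110")  # blocks until the update is confirmed
    #   assert averager.get_group_bits() == "0110"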

    async def _set_group_bits(self, group_bits: str, future: MPFuture):
        try:
            self._matchmaking.group_key_manager.group_bits = group_bits
            return future.set_result(None)
        except Exception as e:
            if not future.done():
                future.set_exception(e)


def _background_thread_fetch_current_state(
    serializer: SerializerBase, pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod
):
    """
    Executed in the host process as a background thread. Fetches the averager state when asked by peers.

    :param serializer: a serializer with which to convert metadata into bytes
    :param pipe: DecentralizedAverager's control pipe (from the host process side)
    :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
    """
    while True:
        try:
            trigger, future = pipe.recv()
        except BaseException as e:
            logger.debug(f"Averager background thread finished: {repr(e)}")
            break

        if trigger == "_SHUTDOWN":
            break

        assert trigger == "_TRIGGER_GET_CURRENT_STATE"
        try:
            get_current_state = get_current_state_ref()
            if get_current_state is None:
                break
            state = get_current_state()
            assert 0 < len(state) <= 3
            if len(state) != 3:
                state = tuple(state) + (None,) * (3 - len(state))
            state_metadata, state_tensors, tensor_infos = state
            del get_current_state

            state_metadata = serializer.dumps(state_metadata)
            state_tensors = tuple(
                tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
            )
            # note: we cast tensors to CPU on the host side to avoid initializing CUDA in the guest process
            future.set_result((state_metadata, state_tensors, tensor_infos))
        except BaseException as e:
            future.set_exception(e)
            logger.warning(e)
            continue


def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
    """A hash that describes a follower's tensor shapes, dtypes, and devices, but not the actual values"""
    schema_dicts = [
        {
            field_name: str(field_value)
            for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()
        }
        for tensor in tensors
    ]
    return DHTID.generate(source=schema_dicts).to_bytes()
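

if __name__ == "__main__":
    # Minimal local demo (a sketch, not part of the library): two peers, each with its own
    # DHT node and averager, form one group and average their tensors once. All argument
    # values below are illustrative, not recommended settings.
    initial_dht = DHT(start=True)
    peer_dht = DHT(start=True, initial_peers=initial_dht.get_visible_maddrs())
    averagers = [
        DecentralizedAverager(
            averaged_tensors=[torch.full((4,), fill_value=float(i))],
            dht=dht,
            start=True,
            prefix="demo_run",
            target_group_size=2,
        )
        for i, dht in enumerate((initial_dht, peer_dht))
    ]
    # both peers must step concurrently to be matched into one group
    controls = [averager.step(wait=False) for averager in averagers]
    for control in controls:
        print("gathered:", control.result())
    with averagers[0].get_tensors() as tensors:
        print("averaged:", tensors[0])  # expected: 0.5 everywhere (the mean of 0.0 and 1.0)
    for averager in averagers:
        averager.shutdown()
    for dht in (initial_dht, peer_dht):
        dht.shutdown()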