  1. """ A background process that averages your tensors with peers """
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import ctypes
  6. import multiprocessing as mp
  7. import os
  8. import random
  9. import threading
  10. import weakref
  11. from dataclasses import asdict
  12. from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
  13. import numpy as np
  14. import torch
  15. from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
  16. from hivemind.averaging.control import AveragingStage, StepControl
  17. from hivemind.averaging.group_info import GroupInfo
  18. from hivemind.averaging.load_balancing import load_balance_peers
  19. from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
  20. from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
  21. from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
  22. from hivemind.dht import DHT, DHTID
  23. from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
  24. from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
  25. from hivemind.proto import averaging_pb2
  26. from hivemind.utils import MPFuture, TensorDescriptor, get_logger
  27. from hivemind.utils.asyncio import (
  28. achain,
  29. aiter_with_timeout,
  30. anext,
  31. as_aiter,
  32. azip,
  33. enter_asynchronously,
  34. switch_to_uvloop,
  35. )
  36. from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
  37. from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
  38. from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
  39. # flavour types
  40. GatheredData = Any
  41. logger = get_logger(__name__)
class DecentralizedAverager(mp.Process, ServicerBase):
    """
    Parameter averaging service. A trainer can run this service in the background to periodically average its
    parameters with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
    group of peers at a time, but (2) all trainers will converge to a global average in a logarithmic number of steps.

    :param averaged_tensors: a sequence of pytorch tensors that will be averaged in each all-reduce
    :param dht: a DHT node that will be used to find groups
    :param start: if True, starts the background process immediately
    :param prefix: a shared prefix for all group keys
    :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
    :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
    :param min_matchmaking_time: when looking for a group, wait for requests for at least this many seconds
    :param compression: optionally compress tensors with this compression algorithm before running all-reduce
    :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
    :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
      towards the (estimated) average by this coefficient. By default, local parameters are set equal to the average.
    :param request_timeout: when looking for a group, wait for a response from the leader for at most this many seconds.
    :note: request_timeout must be smaller than min_matchmaking_time to avoid potential deadlocks.
    :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
    :param bandwidth: if specified, this value represents the network bandwidth available to the averager.
      By default, the averager is assumed to have the average bandwidth of its group.
      If bandwidth == 0, the averager will rely on its groupmates to do all the averaging.
    :param client_mode: if False, this averager will accept incoming requests from other peers.
      If True, the averager will only join existing groups where at least one peer has client_mode=False.
      By default, this flag is copied from DHTNode inside the ``dht`` instance.
    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
      local tensors for averaging
    :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
      with averager.allow_state_sharing = True / False
    :param declare_state_period: re-declare the averager as a donor for load_state_from_peers every this many seconds
    :param allreduce_timeout: spend at most this many seconds on allreduce (after the group is formed)
    :param next_chunk_timeout: during all-reduce and load_state_from_peers, if a peer does not send the next data chunk
      within this number of seconds, consider it failed and proceed with the remaining peers. default: no timeout
    :param sender_timeout: during all_reduce, any sender that fails to send a tensor chunk within this many seconds
      from the previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
    :param reducer_timeout: during all_reduce, any reducer that fails to send a results chunk within this many seconds
      from the previous chunk will be marked as failed and excluded from averaging. default: 2 * sender_timeout
    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating

    Example:

    >>> averager = DecentralizedAverager(...)
    >>> with averager.get_tensors() as tensors:
    >>>     # run some code, modify tensors if necessary
    >>>     tensors[0] += 1
    >>> # do not use tensors after the lock is released
    >>> metadata = averager.step(gather=dict(my_batch_size=32))
    >>> # run averaging once (in-place), gather metadata from groupmates
    >>> with averager.get_tensors() as tensors_after_averaging:
    >>>     pass # use the averaged tensors
    """
    _matchmaking: Matchmaking
    _pending_groups_registered: asyncio.Event
    _state_updated: asyncio.Event
    _p2p: P2P
    serializer = MSGPackSerializer

    def __init__(
        self,
        averaged_tensors: Sequence[torch.Tensor],
        dht: DHT,
        *,
        start: bool,
        prefix: str,
        target_group_size: Optional[int] = None,
        min_group_size: int = 2,
        initial_group_bits: str = "",
        min_matchmaking_time: float = 5.0,
        request_timeout: float = 3.0,
        averaging_alpha: float = 1.0,
        part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
        allreduce_timeout: Optional[float] = None,
        next_chunk_timeout: Optional[float] = None,
        sender_timeout: Optional[float] = None,
        reducer_timeout: Optional[float] = None,
        compression: CompressionBase = NoCompression(),
        state_compression: CompressionBase = NoCompression(),
        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
        bandwidth: Optional[float] = None,
        min_vector_size: int = 0,
        auxiliary: bool = False,
        allow_state_sharing: Optional[bool] = None,
        declare_state_period: float = 30,
        client_mode: Optional[bool] = None,
        daemon: bool = True,
        shutdown_timeout: float = 5,
    ):
  127. assert "." not in prefix, "group prefix must be a string without trailing '.'"
  128. assert bandwidth is None or (
  129. bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
  130. ), "bandwidth must be a non-negative float32"
  131. assert all(bit in "01" for bit in initial_group_bits)
  132. assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
  133. super().__init__()
  134. self.dht = dht
  135. self.prefix = prefix
  136. if client_mode is None:
  137. client_mode = dht.client_mode
  138. if sender_timeout is None:
  139. sender_timeout = next_chunk_timeout
  140. if reducer_timeout is None:
  141. reducer_timeout = 2 * sender_timeout if sender_timeout is not None else None
  142. self.client_mode = client_mode
  143. self._parent_pid = os.getpid()
  144. if self.client_mode:
  145. self.mode = AveragingMode.CLIENT
  146. elif auxiliary:
  147. self.mode = AveragingMode.AUX
  148. else:
  149. self.mode = AveragingMode.NODE
  150. self.daemon = daemon
  151. self._averaged_tensors = tuple(averaged_tensors)
  152. self.lock_averaged_tensors = mp.Lock()
  153. for tensor in self._averaged_tensors:
  154. assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
  155. tensor.share_memory_()
  156. self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
  157. self.schema_hash = compute_schema_hash(self._averaged_tensors)
  158. self.shutdown_timeout = shutdown_timeout
  159. self.next_chunk_timeout = next_chunk_timeout
  160. self.bandwidth = bandwidth
  161. self.matchmaking_kwargs = dict(
  162. servicer_type=type(self),
  163. prefix=prefix,
  164. initial_group_bits=initial_group_bits,
  165. target_group_size=target_group_size,
  166. min_group_size=min_group_size,
  167. request_timeout=request_timeout,
  168. min_matchmaking_time=min_matchmaking_time,
  169. )
  170. self.allreduce_kwargs = dict(
  171. compression=compression,
  172. part_size_bytes=part_size_bytes,
  173. min_vector_size=min_vector_size,
  174. sender_timeout=sender_timeout,
  175. reducer_timeout=reducer_timeout,
  176. )
  177. self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
  178. self._running_groups: Dict[GroupID, asyncio.Future[AllReduceRunner]] = {}
  179. self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True) # a control pipe used to communicate with daemon
  180. self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
  181. self._state_sharing_priority = mp.Value(ctypes.c_double, 0)
  182. if allow_state_sharing is None:
  183. allow_state_sharing = not client_mode and not auxiliary
  184. self.allow_state_sharing = allow_state_sharing
  185. self.declare_state_period = declare_state_period
  186. self.state_compression = state_compression
  187. self.tensor_infos = tensor_infos
  188. self._ready = MPFuture()
  189. # note: we create a background thread weakref and with daemon=True to ensure garbage collection
  190. background_fetcher = threading.Thread(
  191. daemon=True,
  192. target=_background_thread_fetch_current_state,
  193. args=[self.serializer, self._outer_pipe, weakref.WeakMethod(self.get_current_state)],
  194. )
  195. background_fetcher.start()
  196. if start:
  197. self.run_in_background(await_ready=True)

    @property
    def allow_state_sharing(self) -> bool:
        """if set to True, other peers can download this peer's state"""
        return bool(self._allow_state_sharing.value)

    @allow_state_sharing.setter
    def allow_state_sharing(self, value: bool):
        if value and self.client_mode:
            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state")
        else:
            old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
            if value != old_value:
                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
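
    # Illustrative sketch (assumes this averager was created with client_mode=False): state sharing can be toggled
    # at runtime, e.g. to stop serving checkpoints while this peer is itself catching up:
    #
    #     averager.allow_state_sharing = False  # stop advertising state in the DHT
    #     averager.load_state_from_peers()      # download a fresher state from another donor
    #     averager.allow_state_sharing = True   # resume serving state to others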

    @property
    def state_sharing_priority(self) -> float:
        """Other peers will preferentially download state from the peers with the highest priority."""
        return float(self._state_sharing_priority.value)

    @state_sharing_priority.setter
    def state_sharing_priority(self, value: float):
        if value and self.client_mode:
            raise ValueError("State sharing priority is unused: averager in client mode cannot share its state")
        else:
            old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
            if self.allow_state_sharing and value != old_value:
                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))

    async def _trigger_declare_load_state(self):
        # note: previously tried to set mp.Event instead of this. Awaiting it in executor caused degradation in py39
        self._state_updated.set()

    @property
    def peer_id(self) -> PeerID:
        return self.dht.peer_id

    @property
    def request_timeout(self):
        return self._matchmaking.request_timeout

    def run(self):
        """
        Run the averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
        Turns out, using a non-main thread creates a separate OMP pool that works even if the original pool is corrupted
        Read more: https://github.com/pytorch/pytorch/issues/17199
        """
        thread = threading.Thread(target=self._run_internal, daemon=True)
        thread.start()
        thread.join()

    def _run_internal(self):
        """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
        loop = switch_to_uvloop()
        # initialize asyncio synchronization primitives in this event loop
        pipe_semaphore = asyncio.Semaphore(value=0)
        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)

        async def _run():
            try:
                self._p2p = await self.dht.replicate_p2p()
                if not self.client_mode:
                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                else:
                    logger.debug("The averager is running in client mode")

                self._matchmaking = Matchmaking(
                    self._p2p,
                    self.schema_hash,
                    self.dht,
                    client_mode=self.client_mode,
                    **self.matchmaking_kwargs,
                )
                if not self.client_mode:
                    asyncio.create_task(self._declare_for_download_periodically())

                self._state_updated = asyncio.Event()
                self._pending_groups_registered = asyncio.Event()
                self._pending_groups_registered.set()
            except Exception as e:
                # Loglevel is DEBUG since normally the exception is propagated to the caller
                logger.debug(e, exc_info=True)
                self._ready.set_exception(e)
                return
            self._ready.set_result(None)

            while True:
                try:
                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self.request_timeout)
                except asyncio.TimeoutError:
                    pass
                if not self._inner_pipe.poll():
                    continue
                try:
                    method, args, kwargs = self._inner_pipe.recv()
                except (OSError, ConnectionError, RuntimeError) as e:
                    logger.exception(e)
                    await asyncio.sleep(self.request_timeout)
                    continue
                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                if method == "_shutdown":
                    await task
                    break

        loop.run_until_complete(_run())

    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
        """
        Starts the averager in a background process. If await_ready, this method will wait until the background
        averager is ready to process incoming requests, or for :timeout: seconds at most.
        """
        self.start()
        if await_ready:
            self.wait_until_ready(timeout)

    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
        self._ready.result(timeout=timeout)

    def shutdown(self) -> None:
        """Shut down the averager process"""
        if self.is_alive():
            self._outer_pipe.send(("_shutdown", [self.shutdown_timeout], {}))  # shut down the daemon process
            self._inner_pipe.send(("_SHUTDOWN", None))  # shut down the background thread in the host process
            self.join(self.shutdown_timeout)
            if self.is_alive():
                logger.warning("Averager did not shut down within the grace period; terminating it the hard way")
                self.terminate()
        else:
            logger.exception("Averager shutdown has no effect: the process is already not alive")

    async def _shutdown(self, timeout: Optional[float]) -> None:
        remaining_tasks = set()
        for group in self._running_groups.values():
            remaining_tasks.update(group.finalize(cancel=True))
        await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout)

    def __del__(self):
        if self._parent_pid == os.getpid() and self.is_alive():
            self.shutdown()

    def step(
        self,
        gather: Optional[GatheredData] = None,
        scheduled_time: Optional[DHTExpiration] = None,
        weight: Optional[float] = None,
        timeout: Optional[float] = None,
        allow_retries: bool = True,
        require_trigger: bool = False,
        wait: bool = True,
    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
        """
        Set up the averager to look for a group and run one round of averaging, then return the metadata gathered
        from the group.

        :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
          (this operation is known as all-gather). The gathered data will be available as the output of this function.
        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment.
          By default, schedule all-reduce for the current time plus min_matchmaking_time seconds
        :param weight: averaging weight for this peer, int or float, must be non-negative
        :param allow_retries: if the averager fails to run one round of allreduce, this option will allow it to try again
          within the specified timeout
        :param require_trigger: if True, wait for the user to call .allow_allreduce() before running all-reduce
        :param timeout: if the averager was unable to *find* a group in this many seconds, consider allreduce failed
        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
        :returns: on success, update averaged_tensors and return group info; on failure, return None
        """
        if self.mode == AveragingMode.AUX and weight is not None:
            logger.warning("Averager is running in auxiliary mode, weight is unused")
        if scheduled_time is None:
            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
        if weight is None:
            weight = float(self.mode != AveragingMode.AUX)
        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
        assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a non-negative int/float, got {weight}"
        assert not (wait and require_trigger), "Non-asynchronous step cannot wait for trigger (use wait=False)"
        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"

        user_data_for_gather = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
        data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, user_data_for_gather])
        step = StepControl(
            scheduled_time=scheduled_time,
            deadline=deadline,
            allow_retries=allow_retries,
            weight=weight,
            data_for_gather=data_for_gather,
        )

        future_for_init = MPFuture()
        self._outer_pipe.send(("_step", [], dict(step=step, future_for_init=future_for_init)))
        step.attach(*future_for_init.result())

        if not require_trigger:
            step.allow_allreduce()
        return step.result() if wait else step
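
    # Illustrative sketch (not from the original source): the non-blocking form of step() returns a StepControl,
    # which lets the caller pre-schedule averaging and fire it at the right moment, e.g. right after backward():
    #
    #     control = averager.step(wait=False, require_trigger=True, gather=dict(samples=256))
    #     ...                                    # finish the local batch / accumulate gradients here
    #     control.allow_allreduce()              # trigger the pre-scheduled all-reduce
    #     gathered = control.result(timeout=60)  # metadata gathered from groupmates; may raise if averaging fails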

    async def _step(self, *, step: StepControl, future_for_init: MPFuture):
        try:
            trigger, cancel = MPFuture(), MPFuture()
            step.attach(trigger, cancel)
            future_for_init.set_result((trigger, cancel))

            async def find_peers_or_notify_cancel():
                group_info = await self._matchmaking.look_for_group(step)
                if not step.triggered:
                    step.stage = AveragingStage.AWAITING_TRIGGER
                    await step.wait_for_trigger()
                return group_info

            while not step.done():
                try:
                    self._pending_groups_registered.clear()
                    step.stage = AveragingStage.LOOKING_FOR_GROUP
                    matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                    check_cancel_task = asyncio.create_task(step.wait_for_cancel())

                    await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
                    if step.cancelled():
                        matchmaking_task.cancel()
                        raise asyncio.CancelledError()
                    else:
                        check_cancel_task.cancel()

                    group_info = await matchmaking_task
                    if group_info is None:
                        raise AllreduceException("Averaging step failed: could not find a group")

                    with self._register_allreduce_group(group_info):
                        step.stage = AveragingStage.RUNNING_ALLREDUCE
                        step.set_result(
                            await asyncio.wait_for(
                                self._aggregate_with_group(
                                    group_info,
                                    tensor_infos=self.tensor_infos,
                                    weight=step.weight,
                                    **self.allreduce_kwargs,
                                ),
                                timeout=self._allreduce_timeout,
                            )
                        )
                        # averaging is finished, loop will now exit

                except (
                    AllreduceException,
                    MatchmakingException,
                    AssertionError,
                    StopAsyncIteration,
                    asyncio.CancelledError,
                    asyncio.InvalidStateError,
                    P2PHandlerError,
                    DispatchFailure,
                    ControlFailure,
                ) as e:
                    if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
                        if not step.cancelled():
                            logger.exception(e)
                        if not step.done():
                            step.set_exception(e)
                    else:
                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")

        except BaseException as e:
            if not step.done():
                step.set_exception(e)
            raise
        finally:
            step.stage = AveragingStage.FINISHED
            if not step.done():
                step.set_exception(
                    RuntimeError(
                        "Internal sanity check failed: averager.step left future pending."
                        " Please report this to hivemind issues."
                    )
                )

    @contextlib.contextmanager
    def _register_allreduce_group(self, group_info: GroupInfo):
        """Register a given group for one or more all-reduce rounds"""
        try:
            self._running_groups[group_info.group_id] = asyncio.Future()
            self._pending_groups_registered.set()
            yield
        finally:
            maybe_future = self._running_groups.pop(group_info.group_id, None)
            if maybe_future is not None and not maybe_future.done():
                logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
            self._pending_groups_registered.set()

    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
        try:
            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
            modes = tuple(map(AveragingMode, mode_ids))

            # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
            download_bandwidths = [
                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
            ]
            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
            )

            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                await self._run_allreduce_inplace_(local_tensors, group_info, peer_fractions=peer_fractions, **kwargs)
                return user_gathered
        except BaseException as e:
            if isinstance(e, Exception):
                logger.exception(e)
            raise MatchmakingException(f"Unable to run All-Reduce: {e}")

    async def _run_allreduce_inplace_(
        self, tensors: Sequence[torch.Tensor], group_info: GroupInfo, group_id: Optional[bytes] = None, **kwargs
    ):
        """Run one allreduce process to average tensors in place. Can be called multiple times within one aggregation round"""
        group_id = group_info.group_id if group_id is None else group_id

        runner = AllReduceRunner(
            p2p=self._p2p,
            servicer_type=type(self),
            prefix=self.prefix,
            tensors=tensors,
            group_id=group_id,
            ordered_peer_ids=group_info.peer_ids,
            **kwargs,
        )
        assert group_id in self._running_groups, f"Group id {group_id} was not registered in _register_allreduce_group"
        self._running_groups[group_id].set_result(runner)

        if runner.modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
            async for tensor, update in azip(as_aiter(*tensors), runner):
                tensor.add_(update, alpha=self._averaging_alpha)
                self.last_updated = get_dht_time()
                self._state_updated.set()
        else:
            async for _ in runner:
                raise ValueError("aux peers should not receive averaged tensors")

    @contextlib.contextmanager
    def get_tensors(self) -> Sequence[torch.Tensor]:
        """
        A contextmanager that gives the user access to averaged tensors.
        It is guaranteed that the averager will not modify tensors while this context is active.
        Please do not modify the yielded tensors in-place after the context is released.
        """
        with self.lock_averaged_tensors:
            yield self._averaged_tensors

    async def rpc_join_group(
        self, request: averaging_pb2.JoinRequest, context: P2PContext
    ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
        """accept or reject a join request from another averager; if accepted, guide it through the allreduce steps"""
        async for response in self._matchmaking.rpc_join_group(request, context):
            yield response

    async def rpc_aggregate_part(
        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
    ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """a groupmate sends us a part of its tensor; we should average it with other peers and return the result"""
        request = await anext(stream)
        if request.group_id not in self._running_groups:
            # this handles a special case when the leader accepted us to the group AND began allreduce right away,
            # but its response with the group_id was delayed and other peers got to us first
            await self._pending_groups_registered.wait()

        future = self._running_groups.get(request.group_id)
        if future is None:
            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
            return

        group = await future
        async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
            yield message

    async def _declare_for_download_periodically(self):
        download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
        sharing_was_allowed = self.allow_state_sharing
        while True:
            expiration_time = get_dht_time() + self.declare_state_period
            if self.allow_state_sharing or sharing_was_allowed:
                # notify either if sharing is allowed or if it was just switched off (to overwrite the previous message)
                asyncio.create_task(
                    asyncio.wait_for(
                        self.dht.store(
                            download_key,
                            subkey=self.peer_id.to_bytes(),
                            value=self.state_sharing_priority if self.allow_state_sharing else None,
                            expiration_time=expiration_time,
                            return_future=True,
                        ),
                        timeout=expiration_time - get_dht_time(),
                    )
                )
                sharing_was_allowed = self.allow_state_sharing

            # report again after declare_state_period seconds, or sooner if the field was changed by the user
            self._state_updated.clear()
            try:
                await asyncio.wait_for(self._state_updated.wait(), timeout=max(0.0, expiration_time - get_dht_time()))
            except asyncio.TimeoutError:
                pass

    async def rpc_download_state(
        self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
    ) -> AsyncIterator[averaging_pb2.DownloadData]:
        """
        Get the up-to-date trainer state from a peer.
        The state consists of two parts: (serialized_metadata, tensors)

         - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
         - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
        """
        if not self.allow_state_sharing:
            return  # deny the request and direct the peer to the next prospective averager
        metadata, tensors, infos = await self._get_current_state_from_host_process()
        if infos is None:
            infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
        assert len(tensors) == len(infos)

        for tensor, info in zip(tensors, infos):
            for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
                if metadata is not None:
                    yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                    metadata = None
                else:
                    yield averaging_pb2.DownloadData(tensor_part=part)

    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
        """
        Get the current state to send to a peer. Executed in the host process. Meant to be overridden.
        :returns: a tuple of (small metadata, sequence of torch tensors, sequence of CompressionInfo)
        :note: metadata must be serializable with self.serializer (default = MSGPackSerializer)
        """
        with self.get_tensors() as tensors:
            return dict(group_key=self.get_group_bits()), tensors, self.tensor_infos
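
    # Illustrative sketch (not part of the original file): since get_current_state is meant to be overridden,
    # a subclass may attach extra metadata or tensors, as long as it returns the same 3-tuple structure
    # (`local_epoch` below is a hypothetical attribute):
    #
    #     class MyAverager(DecentralizedAverager):
    #         def get_current_state(self):
    #             with self.get_tensors() as tensors:
    #                 metadata = dict(group_key=self.get_group_bits(), epoch=self.local_epoch)
    #                 return metadata, tensors, self.tensor_infos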

    async def _get_current_state_from_host_process(self):
        """Executed in the averager process inside rpc_download_state"""
        future = MPFuture()
        self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
        return await future

    def load_state_from_peers(
        self, wait: bool = True, timeout: Optional[float] = None
    ) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
        """
        Try to download the latest optimizer state from one of the existing peers.
        :returns: on success, return a 2-tuple with (metadata, tensors), where

          - metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
          - tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics

        The exact contents of both metadata and tensors are determined by the get_current_state method
        """
        future = MPFuture()
        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
        return future.result(timeout=timeout) if wait else future
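
    # Illustrative sketch (assumes the donor uses the default get_current_state, so downloaded tensors arrive in the
    # same order as this peer's averaged_tensors): a late-joining peer can catch up as follows:
    #
    #     loaded = averager.load_state_from_peers(timeout=120)
    #     if loaded is not None:
    #         metadata, tensors = loaded
    #         with torch.no_grad(), averager.get_tensors() as local_tensors:
    #             for local_tensor, downloaded in zip(local_tensors, tensors):
    #                 local_tensor.copy_(downloaded)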

    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
        if timeout is not None:
            timeout = self.next_chunk_timeout if self.next_chunk_timeout is not None else self.request_timeout
        try:
            key_manager = self._matchmaking.group_key_manager
            peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
            peer_priority = {
                PeerID(peer_id): (float(info.value), random.random())  # using randomness as a tie breaker
                for peer_id, info in peer_priority.items()
                if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
            }

            if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}")
                future.set_result(None)
                return

            metadata = None
            for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
                if peer != self.peer_id:
                    logger.info(f"Downloading parameters from peer {peer}")
                    try:
                        stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                        stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                        current_tensor_parts, tensors = [], []

                        # TODO merge this with hivemind.compression.deserialize_tensor_stream
                        async for message in aiter_with_timeout(stream, timeout=timeout):
                            if message.metadata:
                                metadata = self.serializer.loads(message.metadata)
                            if message.tensor_part.dtype and current_tensor_parts:
                                # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                                tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
                                current_tensor_parts = []
                            current_tensor_parts.append(message.tensor_part)
                        if current_tensor_parts:
                            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))

                        if not metadata:
                            logger.debug(f"Peer {peer} did not send its state")
                            continue

                        logger.info(f"Finished downloading state from {peer}")
                        future.set_result((metadata, tensors))
                        return
                    except Exception as e:
                        logger.exception(f"Failed to download state from {peer} - {repr(e)}")

        finally:
            if not future.done():
                future.set_result(None)

    def get_group_bits(self, wait: bool = True):
        """
        :param wait: if True, return bits immediately. Otherwise return awaitable MPFuture
        :returns: averager's current group key bits (without prefix)
        """
        future = MPFuture()
        self._outer_pipe.send(("_get_group_bits", [], dict(future=future)))
        return future.result() if wait else future

    async def _get_group_bits(self, future: MPFuture):
        future.set_result(self._matchmaking.group_key_manager.group_bits)

    def set_group_bits(self, group_bits: str, wait: bool = True):
        """
        :param group_bits: group bits (string of '0' or '1') to be used in averager's group key
        :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
        """
        future = MPFuture()
        assert all(bit in "01" for bit in group_bits)
        self._outer_pipe.send(("_set_group_bits", [], dict(group_bits=group_bits, future=future)))
        return future.result() if wait else future
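
    # Illustrative sketch (not from the original source): group bits select the matchmaking bucket, so peers can be
    # pre-partitioned into separate averaging pools (e.g. one pool per data shard):
    #
    #     averager.set_group_bits("01")             # prefer matching with peers that currently use the same bits
    #     assert averager.get_group_bits() == "01"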

    async def _set_group_bits(self, group_bits: str, future: MPFuture):
        try:
            self._matchmaking.group_key_manager.group_bits = group_bits
            return future.set_result(None)
        except Exception as e:
            if not future.done():
                future.set_exception(e)


def _background_thread_fetch_current_state(
    serializer: SerializerBase, pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod
):
    """
    Executed in the host process as a background thread. Fetches the averager state when asked by peers.
    :param serializer: a serializer with which to convert metadata into bytes
    :param pipe: DecentralizedAverager's control pipe (from the host process side)
    :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
    """
    while True:
        try:
            trigger, future = pipe.recv()
        except BaseException as e:
            logger.debug(f"Averager background thread finished: {repr(e)}")
            break

        if trigger == "_SHUTDOWN":
            break

        assert trigger == "_TRIGGER_GET_CURRENT_STATE"
        try:
            get_current_state = get_current_state_ref()
            if get_current_state is None:
                break
            state = get_current_state()
            assert 0 < len(state) <= 3
            if len(state) != 3:
                state = tuple(state + (None,) * (3 - len(state)))
            state_metadata, state_tensors, tensor_infos = state
            del get_current_state

            state_metadata = serializer.dumps(state_metadata)
            state_tensors = tuple(
                tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
            )
            # note: we cast tensors to CPU on the host side to avoid initializing cuda in the guest process
            future.set_result((state_metadata, state_tensors, tensor_infos))
        except BaseException as e:
            future.set_exception(e)
            logger.warning(e)
            continue


def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
    """A hash that describes the follower's tensor shapes, dtypes, and devices, but not the actual values"""
    schema_dicts = [
        {
            field_name: str(field_value)
            for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()
        }
        for tensor in tensors
    ]
    return DHTID.generate(source=schema_dicts).to_bytes()
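

# The block below is an illustrative smoke-test sketch, not part of the original module: it wires up two local DHT
# peers and checks that their tensors converge to the common average. It assumes hivemind's DHT can be started
# locally and that the default matchmaking settings are acceptable for a quick local run.
if __name__ == "__main__":
    import hivemind

    first_dht = hivemind.DHT(start=True)
    second_dht = hivemind.DHT(start=True, initial_peers=first_dht.get_visible_maddrs())

    averagers = [
        DecentralizedAverager(
            averaged_tensors=[torch.full((8,), float(i))],  # peer 0 holds zeros, peer 1 holds ones
            dht=peer_dht,
            start=True,
            prefix="demo_run",
            target_group_size=2,
            min_matchmaking_time=1.0,
        )
        for i, peer_dht in enumerate((first_dht, second_dht))
    ]

    # run one asynchronous averaging step on both peers so that they can find each other
    controls = [averager.step(wait=False) for averager in averagers]
    for control in controls:
        print("gathered:", control.result())

    with averagers[0].get_tensors() as tensors:
        print("averaged tensor:", tensors[0])  # expected to be close to 0.5 everywhere

    for averager in averagers:
        averager.shutdown()
    for peer_dht in (first_dht, second_dht):
        peer_dht.shutdown()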