averager.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697
  1. """ A background process that averages your tensors with peers """
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import ctypes
  6. import multiprocessing as mp
  7. import os
  8. import threading
  9. import weakref
  10. from concurrent.futures.thread import ThreadPoolExecutor
  11. from dataclasses import asdict
  12. from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
  13. import numpy as np
  14. import torch
  15. from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
  16. from hivemind.averaging.group_info import GroupInfo
  17. from hivemind.averaging.load_balancing import load_balance_peers
  18. from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
  19. from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
  20. from hivemind.dht import DHT, DHTID
  21. from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
  22. from hivemind.proto import averaging_pb2, runtime_pb2
  23. from hivemind.utils import MPFuture, TensorDescriptor, get_logger
  24. from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
  25. from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
  26. from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
  27. from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
  28. from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
  29. # flavour types
  30. GatheredData = Any
  31. logger = get_logger(__name__)
  32. class DecentralizedAverager(mp.Process, ServicerBase):
  33. """
  34. Parameter averaging service. A trainer can run this service in background to periodically average his parameters
  35. with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
  36. group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.
  37. :param averaged_tensors: a sequence of pytorch tensors that will be averaged in each all-reduce
  38. :param dht: a DHT node that will be used to find groups
  39. :param start: if True, starts the background process immediately
  40. :param prefix: a shared prefix for all group keys
  41. :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
  42. :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
  43. :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
  44. note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
  45. :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
  46. :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
  47. :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
  48. towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
  49. :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
  50. :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
  51. :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
  52. :param bandwidth: if specified, this value represents the network bandwidth available to averager.
  53. By default, the averager is assumed to have the average bandwidth of his group.
  54. If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
  55. :param client_mode: if False, this averager will accept incoming requests from other peers.
  56. if True, the averager will only join existing groups where at least one peer has client_mode=False.
  57. By default, this flag is copied from DHTNode inside the ``dht`` instance.
  58. :param auxiliary: if this flag is specified, averager.step will only assist others without sending
  59. local tensors for averaging
  60. :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
  61. with averager.allow_state_sharing = True / False
  62. :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
  63. Example:
  64. >>> averager = DecentralizedAverager(...)
  65. >>> with averager.get_tensors() as tensors:
  66. >>> # run some code, modify tensors if necessary
  67. >>> tensors[0] += 1
  68. >>> # do not use tensors after the lock is released
  69. >>> metadata = averager.step(gather=dict(my_batch_size=32))
  70. >>> # run averaging once (in-place), gather metadata from groupmates
  71. >>> with averager.get_tensors() as tensors_after_averaging:
  72. >>> pass # use the averaged tensors
  73. """
  74. _matchmaking: Matchmaking
  75. _pending_group_assembled: asyncio.Event
  76. serializer = MSGPackSerializer
  77. def __init__(
  78. self,
  79. averaged_tensors: Sequence[torch.Tensor],
  80. dht: DHT,
  81. *,
  82. start: bool,
  83. prefix: str,
  84. target_group_size: int,
  85. min_group_size: int = 2,
  86. initial_group_bits: str = "",
  87. averaging_expiration: float = 15,
  88. request_timeout: float = 3,
  89. averaging_alpha: float = 1.0,
  90. part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
  91. allreduce_timeout: Optional[float] = None,
  92. compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
  93. bandwidth: Optional[float] = None,
  94. min_vector_size: int = 0,
  95. auxiliary: bool = False,
  96. allow_state_sharing: Optional[bool] = None,
  97. client_mode: Optional[bool] = None,
  98. daemon: bool = True,
  99. shutdown_timeout: float = 5,
  100. ):
  101. assert "." not in prefix, "group prefix must be a string without trailing '.'"
  102. assert bandwidth is None or (
  103. bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
  104. ), "bandwidth must be a non-negative float32"
  105. if not is_power_of_two(target_group_size):
  106. logger.warning("It is recommended to set target_group_size to a power of 2.")
  107. assert all(bit in "01" for bit in initial_group_bits)
  108. assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
  109. super().__init__()
  110. self.dht = dht
  111. self.prefix = prefix
  112. if client_mode is None:
  113. client_mode = dht.client_mode
  114. self.client_mode = client_mode
  115. self._parent_pid = os.getpid()
  116. if self.client_mode:
  117. self.mode = AveragingMode.CLIENT
  118. elif auxiliary:
  119. self.mode = AveragingMode.AUX
  120. else:
  121. self.mode = AveragingMode.NODE
  122. self.daemon = daemon
  123. self._averaged_tensors = tuple(averaged_tensors)
  124. self.lock_averaged_tensors = mp.Lock()
  125. self.last_updated: DHTExpiration = -float("inf")
  126. for tensor in self._averaged_tensors:
  127. assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
  128. tensor.share_memory_()
  129. self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
  130. self.schema_hash = compute_schema_hash(self._averaged_tensors)
  131. self.shutdown_timeout = shutdown_timeout
  132. self.bandwidth = bandwidth
  133. self.matchmaking_kwargs = dict(
  134. servicer_type=type(self),
  135. prefix=prefix,
  136. initial_group_bits=initial_group_bits,
  137. target_group_size=target_group_size,
  138. min_group_size=min_group_size,
  139. averaging_expiration=averaging_expiration,
  140. request_timeout=request_timeout,
  141. )
  142. self.allreduce_kwargs = dict(
  143. compression_type=compression_type, part_size_bytes=part_size_bytes, min_vector_size=min_vector_size
  144. )
  145. self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
  146. self._running_groups: Dict[GroupID, AllReduceRunner] = {} # one or more assembled groups that run all-reduce
  147. self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True) # a control pipe used to communicate with daemon
  148. self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
  149. if allow_state_sharing is None:
  150. allow_state_sharing = not client_mode and not auxiliary
  151. self.allow_state_sharing = allow_state_sharing
  152. self._ready = MPFuture()
  153. # note: we create a background thread weakref and with daemon=True to ensure garbage collection
  154. background_fetcher = threading.Thread(
  155. daemon=True,
  156. target=_background_thread_fetch_current_state,
  157. args=[self.serializer, self._outer_pipe, weakref.WeakMethod(self.get_current_state)],
  158. )
  159. background_fetcher.start()
  160. if start:
  161. self.run_in_background(await_ready=True)
  162. @property
  163. def allow_state_sharing(self) -> bool:
  164. """if set to True, other peers can download this peer's state"""
  165. return bool(self._allow_state_sharing.value)
  166. @allow_state_sharing.setter
  167. def allow_state_sharing(self, value: bool):
  168. if value and self.client_mode:
  169. raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
  170. else:
  171. self._allow_state_sharing.value = value
  172. @property
  173. def peer_id(self) -> PeerID:
  174. return self.dht.peer_id
  175. def run(self):
  176. """
  177. Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
  178. Turns out, using a non-main thread creates a separate OMP pool that works even if the original pool is corrupted
  179. Read more: https://github.com/pytorch/pytorch/issues/17199
  180. """
  181. thread = threading.Thread(target=self._run_internal, daemon=True)
  182. thread.start()
  183. thread.join()
  184. def _run_internal(self):
  185. """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
  186. loop = switch_to_uvloop()
  187. # initialize asyncio synchronization primitives in this event loop
  188. with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
  189. async def _run():
  190. try:
  191. self._p2p = await self.dht.replicate_p2p()
  192. if not self.client_mode:
  193. await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
  194. else:
  195. logger.debug(f"The averager is running in client mode.")
  196. self._matchmaking = Matchmaking(
  197. self._p2p,
  198. self.schema_hash,
  199. self.dht,
  200. client_mode=self.client_mode,
  201. **self.matchmaking_kwargs,
  202. )
  203. logger.debug(f"The 1")
  204. if not self.client_mode:
  205. asyncio.create_task(self._declare_for_download_periodically())
  206. self._pending_group_assembled = asyncio.Event()
  207. self._pending_group_assembled.set()
  208. logger.debug(f"The 2")
  209. except Exception as e:
  210. # Loglevel is DEBUG since normally the exception is propagated to the caller
  211. logger.debug(e, exc_info=True)
  212. self._ready.set_exception(e)
  213. return
  214. self._ready.set_result(None)
  215. logger.debug(f"The 3")
  216. while True:
  217. try:
  218. method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
  219. logger.debug(f"The 4")
  220. except (OSError, ConnectionError) as e:
  221. logger.exception(e)
  222. await asyncio.sleep(self._matchmaking.request_timeout)
  223. continue
  224. task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
  225. if method == "_shutdown":
  226. await task
  227. break
  228. loop.run_until_complete(_run())
  229. def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
  230. """
  231. Starts averager in a background process. if await_ready, this method will wait until background dht
  232. is ready to process incoming requests or for :timeout: seconds max.
  233. """
  234. self.start()
  235. if await_ready:
  236. self.wait_until_ready(timeout)
  237. def wait_until_ready(self, timeout: Optional[float] = None) -> None:
  238. self._ready.result(timeout=timeout)
  239. def shutdown(self) -> None:
  240. """Shut down the averager process"""
  241. if self.is_alive():
  242. self._outer_pipe.send(("_shutdown", [None], {})) # shut down the daemon process
  243. self._inner_pipe.send(("_SHUTDOWN", None)) # shut down background thread in master
  244. self.join(self.shutdown_timeout)
  245. if self.is_alive():
  246. logger.warning("Averager did not shut down within the grace period; terminating it the hard way.")
  247. self.terminate()
  248. else:
  249. logger.exception("Averager shutdown has no effect: the process is already not alive")
  250. async def _shutdown(self, timeout: Optional[float] = None) -> None:
  251. remaining_tasks = set()
  252. for group in self._running_groups.values():
  253. remaining_tasks.update(group.finalize(cancel=True))
  254. await asyncio.gather(*remaining_tasks)
  255. def __del__(self):
  256. if self._parent_pid == os.getpid() and self.is_alive():
  257. self.shutdown()
  258. def step(
  259. self,
  260. gather: Optional[GatheredData] = None,
  261. weight: Optional[float] = None,
  262. timeout: Optional[float] = None,
  263. allow_retries: bool = True,
  264. wait: bool = True,
  265. ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
  266. """
  267. Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
  268. :param gather: optionally send this informaton to all peers in the next group and gather it from every groupmate
  269. (this operation is known as all-gather). The gathered data will be available as the output of this function.
  270. :param weight: averaging weight for this peer, int or float, must be strictly positive
  271. :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
  272. within the specified timeout
  273. :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failedK
  274. :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
  275. :returns: on success, update averaged_tensors and return group info; on failure, return None
  276. """
  277. if self.mode == AveragingMode.AUX and weight is not None:
  278. logger.warning("Averager is running in auxiliary mode, weight is unused.")
  279. if weight is None:
  280. weight = float(self.mode != AveragingMode.AUX)
  281. assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
  282. future = MPFuture()
  283. gather_binary = self.serializer.dumps(
  284. gather
  285. ) # serialize here to avoid loading modules in the averager process
  286. self._outer_pipe.send(
  287. (
  288. "_step",
  289. [],
  290. dict(
  291. future=future,
  292. gather_binary=gather_binary,
  293. weight=weight,
  294. allow_retries=allow_retries,
  295. timeout=timeout,
  296. ),
  297. )
  298. )
  299. return future.result() if wait else future
  300. async def _step(
  301. self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
  302. ):
  303. start_time = get_dht_time()
  304. try:
  305. while not future.done():
  306. try:
  307. self._pending_group_assembled.clear()
  308. data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
  309. group_info = await self._matchmaking.look_for_group(
  310. timeout=timeout, data_for_gather=data_for_gather
  311. )
  312. if group_info is None:
  313. raise AllreduceException("Averaging step failed: could not find a group.")
  314. future.set_result(
  315. await asyncio.wait_for(
  316. self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout
  317. )
  318. )
  319. # averaging is finished, loop will now exit
  320. except (
  321. AllreduceException,
  322. MatchmakingException,
  323. AssertionError,
  324. StopAsyncIteration,
  325. asyncio.CancelledError,
  326. asyncio.InvalidStateError,
  327. P2PHandlerError,
  328. ) as e:
  329. time_elapsed = get_dht_time() - start_time
  330. if not allow_retries or (timeout is not None and timeout < time_elapsed):
  331. logger.exception(f"Averager caught {repr(e)}")
  332. future.set_exception(e)
  333. else:
  334. logger.warning(f"Averager caught {repr(e)}, retrying")
  335. except BaseException as e:
  336. if not future.done():
  337. future.set_exception(e)
  338. raise
  339. finally:
  340. if not future.done():
  341. future.set_exception(
  342. RuntimeError(
  343. "Internal sanity check failed: averager.step left future pending."
  344. " Please report this to hivemind issues."
  345. )
  346. )
  347. async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
  348. """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
  349. try:
  350. weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
  351. user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
  352. modes = tuple(map(AveragingMode, mode_ids))
  353. # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
  354. download_bandwidths = [
  355. thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
  356. ]
  357. peer_fractions = await asyncio.get_event_loop().run_in_executor(
  358. None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
  359. )
  360. async with self.get_tensors_async() as local_tensors:
  361. allreduce = AllReduceRunner(
  362. p2p=self._p2p,
  363. servicer_type=type(self),
  364. prefix=self.prefix,
  365. group_id=group_info.group_id,
  366. tensors=local_tensors,
  367. ordered_peer_ids=group_info.peer_ids,
  368. peer_fractions=peer_fractions,
  369. weights=weights,
  370. gathered=user_gathered,
  371. modes=modes,
  372. **kwargs,
  373. )
  374. with self.register_allreduce_group(group_info.group_id, allreduce):
  375. # actually run all-reduce
  376. averaging_outputs = [output async for output in allreduce]
  377. if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
  378. assert len(local_tensors) == len(self._averaged_tensors)
  379. for tensor, update in zip(local_tensors, averaging_outputs):
  380. tensor.add_(update, alpha=self._averaging_alpha)
  381. self.last_updated = get_dht_time()
  382. return allreduce.gathered
  383. except BaseException as e:
  384. logger.exception(e)
  385. raise MatchmakingException(f"Unable to run All-Reduce: {e}")
  386. @contextlib.contextmanager
  387. def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
  388. """registers a given all-reduce runner to listen for incoming connections"""
  389. try:
  390. self._running_groups[group_id] = allreduce
  391. self._pending_group_assembled.set()
  392. yield
  393. finally:
  394. self._running_groups.pop(group_id, None)
  395. self._pending_group_assembled.set()
  396. @contextlib.contextmanager
  397. def get_tensors(self) -> Sequence[torch.Tensor]:
  398. """
  399. A contextmanager that gives user access to averaged tensors.
  400. It is guaranteed that the averager will not modify tensors while this context is active.
  401. Please do not modify the yielded tensors in-place after the context is released.
  402. """
  403. with self.lock_averaged_tensors:
  404. yield self._averaged_tensors
  405. self.last_updated = get_dht_time()
  406. @contextlib.asynccontextmanager
  407. async def get_tensors_async(self) -> Sequence[torch.Tensor]:
  408. """Like get_tensors, but uses an asynchronous contextmanager"""
  409. try:
  410. await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
  411. yield self._averaged_tensors
  412. finally:
  413. self.lock_averaged_tensors.release()
  414. async def rpc_join_group(
  415. self, request: averaging_pb2.JoinRequest, context: P2PContext
  416. ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
  417. """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
  418. async for response in self._matchmaking.rpc_join_group(request, context):
  419. yield response
  420. async def rpc_aggregate_part(
  421. self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
  422. ) -> AsyncIterator[averaging_pb2.AveragingData]:
  423. """a groupmate sends us a part of his tensor; we should average it with other peers and return the result"""
  424. request = await anext(stream)
  425. if request.group_id not in self._running_groups:
  426. # this handles a special case when leader accepted us to group AND began allreduce right away,
  427. # but his response with group_id was delayed and other peers got to us first
  428. await self._pending_group_assembled.wait()
  429. group = self._running_groups.get(request.group_id)
  430. if group is None:
  431. yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
  432. return
  433. async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
  434. yield message
  435. async def _declare_for_download_periodically(self):
  436. download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
  437. while True:
  438. if self.allow_state_sharing:
  439. asyncio.create_task(
  440. asyncio.wait_for(
  441. self.dht.store(
  442. download_key,
  443. subkey=self.peer_id.to_bytes(),
  444. value=self.last_updated,
  445. expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
  446. return_future=True,
  447. ),
  448. timeout=self._matchmaking.averaging_expiration,
  449. )
  450. )
  451. await asyncio.sleep(self._matchmaking.averaging_expiration)
  452. async def rpc_download_state(
  453. self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
  454. ) -> AsyncIterator[averaging_pb2.DownloadData]:
  455. """
  456. Get the up-to-date trainer state from a peer.
  457. The state consists of two parts: (serialized_metadata, tensors)
  458. - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
  459. - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
  460. """
  461. if not self.allow_state_sharing:
  462. return # deny request and direct peer to the next prospective averager
  463. metadata, tensors = await self._get_current_state_from_host_process()
  464. for tensor in tensors:
  465. for part in split_for_streaming(serialize_torch_tensor(tensor)):
  466. if metadata is not None:
  467. yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
  468. metadata = None
  469. else:
  470. yield averaging_pb2.DownloadData(tensor_part=part)
  471. def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
  472. """
  473. Get current state and send it to a peer. executed in the host process. Meant to be overriden.
  474. :returns: a tuple of (small metadata, sequence of torch tensors)
  475. :note: metadata must be seriablizable with self.serializer (default = MSGPackSerializer)
  476. """
  477. with self.get_tensors() as tensors:
  478. return dict(group_key=self.get_group_bits()), tensors
  479. async def _get_current_state_from_host_process(self):
  480. """Executed in the averager process inside rpc_download_state"""
  481. future = MPFuture()
  482. self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
  483. return await future
  484. def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
  485. """
  486. Try to download the latest optimizer state one of the existing peer.
  487. :returns: on success, return a 2-tuple with (metadata, tensors), where
  488. - metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
  489. - tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics
  490. The exact contents of both metadata and tensors are determined by get_current_state method
  491. """
  492. future = MPFuture()
  493. self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
  494. return future.result() if wait else future
  495. async def _load_state_from_peers(self, future: MPFuture):
  496. try:
  497. key_manager = self._matchmaking.group_key_manager
  498. peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
  499. peer_priority = {
  500. PeerID(peer_id): float(info.value)
  501. for peer_id, info in peer_priority.items()
  502. if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
  503. }
  504. if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
  505. logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
  506. future.set_result(None)
  507. return
  508. metadata = None
  509. for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
  510. if peer != self.peer_id:
  511. logger.info(f"Downloading parameters from peer {peer}")
  512. try:
  513. stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
  514. stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
  515. current_tensor_parts, tensors = [], []
  516. async for message in stream:
  517. if message.metadata:
  518. metadata = self.serializer.loads(message.metadata)
  519. if message.tensor_part.dtype and current_tensor_parts:
  520. # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
  521. tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
  522. current_tensor_parts = []
  523. current_tensor_parts.append(message.tensor_part)
  524. if current_tensor_parts:
  525. tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
  526. if not metadata:
  527. logger.debug(f"Peer {peer} did not send its state.")
  528. continue
  529. logger.info(f"Finished downloading state from {peer}")
  530. future.set_result((metadata, tensors))
  531. self.last_updated = get_dht_time()
  532. return
  533. except Exception as e:
  534. logger.exception(f"Failed to download state from {peer} - {repr(e)}")
  535. finally:
  536. if not future.done():
  537. logger.warning("Averager could not load state from peers: all requests have failed.")
  538. future.set_result(None)
  539. def get_group_bits(self, wait: bool = True):
  540. """
  541. :param wait: if True, return bits immediately. Otherwise return awaitable MPFuture
  542. :returns: averager's current group key bits (without prefix)
  543. """
  544. future = MPFuture()
  545. self._outer_pipe.send(("_get_group_bits", [], dict(future=future)))
  546. return future.result() if wait else future
  547. async def _get_group_bits(self, future: MPFuture):
  548. future.set_result(self._matchmaking.group_key_manager.group_bits)
  549. def set_group_bits(self, group_bits: str, wait: bool = True):
  550. """
  551. :param group_bits: group bits (string of '0' or '1') to be used in averager's group key
  552. :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
  553. """
  554. future = MPFuture()
  555. assert all(bit in "01" for bit in group_bits)
  556. self._outer_pipe.send(("_set_group_bits", [], dict(group_bits=group_bits, future=future)))
  557. return future.result() if wait else future
  558. async def _set_group_bits(self, group_bits: str, future: MPFuture):
  559. try:
  560. self._matchmaking.group_key_manager.group_bits = group_bits
  561. return future.set_result(None)
  562. except Exception as e:
  563. if not future.done():
  564. future.set_exception(e)
  565. def is_power_of_two(n):
  566. """Check whether n is a power of 2"""
  567. return (n != 0) and (n & (n - 1) == 0)
  568. def _background_thread_fetch_current_state(
  569. serializer: SerializerBase, pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod
  570. ):
  571. """
  572. Executed in the host process as a background thread. Fetches the averager state when asked by peers.
  573. :param serializer: a serializer with which to convert metadata into bytes
  574. :param pipe: DecentralizedAverager's control pipe (from host process side)
  575. :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
  576. """
  577. while True:
  578. try:
  579. trigger, future = pipe.recv()
  580. except BaseException as e:
  581. logger.debug(f"Averager background thread finished: {repr(e)}")
  582. break
  583. if trigger == "_SHUTDOWN":
  584. break
  585. assert trigger == "_TRIGGER_GET_CURRENT_STATE"
  586. try:
  587. get_current_state = get_current_state_ref()
  588. if get_current_state is None:
  589. break
  590. state_metadata, state_tensors = get_current_state()
  591. del get_current_state
  592. state_metadata = serializer.dumps(state_metadata)
  593. state_tensors = tuple(
  594. tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
  595. )
  596. # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
  597. future.set_result((state_metadata, state_tensors))
  598. except BaseException as e:
  599. future.set_exception(e)
  600. logger.warning(e)
  601. continue
  602. def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
  603. """A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values"""
  604. schema_dicts = [
  605. {
  606. field_name: str(field_value)
  607. for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()
  608. }
  609. for tensor in tensors
  610. ]
  611. return DHTID.generate(source=schema_dicts).to_bytes()