  1. """ A background process that averages your tensors with peers """
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import ctypes
  6. import multiprocessing as mp
  7. import os
  8. import threading
  9. import uuid
  10. import weakref
  11. from concurrent.futures.thread import ThreadPoolExecutor
  12. from dataclasses import asdict
  13. from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
  14. import grpc
  15. from grpc._cython.cygrpc import InternalError
  16. import torch
  17. import numpy as np
  18. from hivemind.dht import DHT, DHTID
  19. from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
  20. from hivemind.client.averaging.load_balancing import load_balance_peers
  21. from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
  22. from hivemind.client.averaging.group_info import GroupInfo
  23. from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
  24. from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
  25. serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
  26. from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
  27. from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
  28. from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
  29. from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor
  30. # flavour types
  31. StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
  32. DataForGather = Any
  33. logger = get_logger(__name__)
  34. DEFAULT_CHUNK_SIZE_BYTES = 2 ** 16
class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
    """
    Parameter averaging service. A trainer can run this service in background to periodically average his parameters
    with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
    group of peers at a time, but (2) all trainers will converge to global average in a logarithmic number of steps.

    :param averaged_tensors: a sequence of pytorch tensors that will be averaged in each all-reduce
    :param dht: a DHT node that will be used to find groups
    :param start: if True, starts the background process immediately
    :param prefix: a shared prefix for all group keys
    :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
    :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
    :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
      towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
    :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
    :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
    :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
    :param throughput: if specified, this value represents the network bandwidth available to averager.
      By default, the averager is assumed to have the average bandwidth of his group.
      If throughput == 0, averager will rely on its groupmates to do all the averaging.
    :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce
      if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
      see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
    :param kwargs: extra parameters forwarded to grpc.aio.server

    Example:

    >>> averager = DecentralizedAverager(...)
    >>> with averager.get_tensors() as tensors:
    >>>     # run some code, modify tensors if necessary
    >>>     tensors[0] += 1
    >>> # do not use tensors after the lock is released
    >>> metadata = averager.step(gather=dict(my_batch_size=32))
    >>> # run averaging once (in-place), gather metadata from groupmates
    >>> with averager.get_tensors() as tensors_after_averaging:
    >>>     pass  # use the averaged tensors
    """
    _matchmaking: Matchmaking
    _pending_group_assembled: asyncio.Event
    serializer = MSGPackSerializer

    def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
                 prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
                 averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
                 allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                 compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                 throughput: Optional[float] = None, min_vector_size: int = 0,
                 listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
                 channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
        assert '.' not in prefix, "group prefix must be a string without '.'"
        assert throughput is None or (throughput >= 0 and np.isfinite(np.float32(throughput))), \
            "throughput must be a non-negative float32"
        if not is_power_of_two(target_group_size):
            logger.warning("It is recommended to set target_group_size to a power of 2.")
        assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)

        super().__init__()
        self.dht = dht
        self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
        self.channel_options = channel_options
        self.daemon = daemon

        self._averaged_tensors = tuple(averaged_tensors)
        self.lock_averaged_tensors = mp.Lock()
        self.last_updated: DHTExpiration = -float('inf')
        for tensor in self._averaged_tensors:
            assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
            tensor.share_memory_()
        self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
        self.schema_hash = compute_schema_hash(self._averaged_tensors)
        self._throughput = throughput

        self.matchmaking_kwargs = dict(
            prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
            min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
        self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes,
                                     min_vector_size=min_vector_size)
        self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce

        self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
        self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
        self._averager_endpoint: Optional[Endpoint] = None
        if not self.listen:
            self._averager_endpoint = f'client::{uuid.uuid4()}'

        self.ready = mp.Event()  # whether the averager process has started (and is ready for incoming requests)

        # note: the fetcher thread is daemonized and holds only a weak reference to allow garbage collection
        background_fetcher = threading.Thread(
            daemon=True, target=_background_thread_fetch_current_state,
            args=[self.serializer, self.pipe, weakref.WeakMethod(self.get_current_state)])
        background_fetcher.start()
        if start:
            self.run_in_background(await_ready=True)
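
    # A rough construction sketch (the single-machine setup and `hivemind.DHT(start=True)` below are
    # illustrative assumptions, not a prescribed recipe):
    # >>> import torch, hivemind
    # >>> dht = hivemind.DHT(start=True)
    # >>> averager = DecentralizedAverager([torch.zeros(100)], dht=dht, start=True,
    # >>>                                  prefix='my_experiment', target_group_size=4)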

    @property
    def port(self) -> Optional[Port]:
        return self._port.value if self._port.value != 0 else None

    @property
    def endpoint(self) -> Optional[Endpoint]:
        if self.listen and self._averager_endpoint is None:
            assert self.port is not None, "Averager is not running yet"
            self._averager_endpoint = f"{self.dht.get_visible_address()}:{self.port}"
            logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
        return self._averager_endpoint

    def __repr__(self):
        return f"{self.__class__.__name__}({self.endpoint})"

    def run(self):
        """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
        loop = switch_to_uvloop()
        # initialize asyncio synchronization primitives in this event loop
        pipe_awaiter = ThreadPoolExecutor(max_workers=1)

        async def _run():
            grpc.aio.init_grpc_aio()
            if self.listen:
                server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
                averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, server)
                found_port = server.add_insecure_port(self.listen_on)
                assert found_port != 0, f"Failed to listen to {self.listen_on}"
                self._port.value = found_port
                await server.start()
            else:
                logger.info("The averager is running in an experimental client mode, please report any bugs.")

            self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
                                            client_mode=not self.listen)
            if self.listen:
                asyncio.create_task(self._declare_for_download_periodically())

            self._pending_group_assembled = asyncio.Event()
            self._pending_group_assembled.set()
            self.ready.set()

            while True:
                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
                asyncio.create_task(getattr(self, method)(*args, **kwargs))

        loop.run_until_complete(_run())

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts averager in a background process. If await_ready, this method will wait until the background averager
        is ready to process incoming requests or for :timeout: seconds max.
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")

    def shutdown(self) -> None:
        """ Shut down the averager process """
        # TODO notify peers before terminating
        if self._parent_pid != os.getpid() or self.is_alive():
            self._pipe.send(('_SHUTDOWN', None))
            self.terminate()
        else:
            logger.warning("Averager shutdown has no effect: the process is not alive")

    def __del__(self):
        if self._parent_pid != os.getpid() or self.is_alive():
            self.shutdown()

    def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None,
             allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
        """
        Set up the averager to look for a group and run one round of averaging; on success, return the metadata
        gathered from groupmates; on failure, return None.

        :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
          (this operation is known as all-gather). The gathered data will be available as the output of this function.
        :param weight: averaging weight for this peer, int or float, must be strictly positive
        :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
          within the specified timeout
        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
        :returns: on success, update averaged_tensors and return group info; on failure, return None
        """
        assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
        future, _future = MPFuture.make_pair()
        gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
        self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
                                          allow_retries=allow_retries, timeout=timeout)))
        return future.result() if wait else future
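
    # Example (sketch): one averaging round that also all-gathers a small metadata dict from every groupmate.
    # The keys inside `gather` are chosen by the user; `my_batch_size` below is only an illustration.
    # >>> gathered = averager.step(gather=dict(my_batch_size=32), timeout=60)
    # >>> if gathered is not None:  # None means the step failed within the timeout
    # >>>     total_batch = sum(peer_info['my_batch_size'] for peer_info in gathered.values())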

    async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
                    allow_retries: bool, timeout: Optional[float]):
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        while not future.done():
            try:
                self._pending_group_assembled.clear()
                data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
                group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
                if group_info is None:
                    raise AllreduceException("Averaging step failed: could not find a group.")
                group_id = group_info.group_id

                allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
                self._running_groups[group_id] = allreduce_runner
                self._pending_group_assembled.set()
                await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
                await loop.run_in_executor(None, self.update_tensors, allreduce_runner)

                # averaging is finished, exit the loop
                future.set_result(allreduce_runner.gathered)

            except (AllreduceException, MatchmakingException, AssertionError,
                    asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None and timeout < time_elapsed):
                    logger.warning(f"Averager caught {e}")
                    future.set_result(None)
                else:
                    logger.warning(f"Averager caught {e}, retrying")

            except Exception as e:
                future.set_exception(e)
                raise
            finally:
                _ = self._running_groups.pop(group_id, None)
                self._pending_group_assembled.set()

    async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
        """ Use a group description found by Matchmaking to form AllreduceRunner """
        try:
            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))

            # compute optimal part sizes from peer throughputs
            incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
            part_sizes = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
            async with self.get_tensors_async() as averaged_tensors:
                return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
                                       ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
                                       weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
        except Exception as e:
            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
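
    # Rough intuition for the load-balancing call above: each peer reports its incoming throughput and
    # load_balance_peers splits self.total_size elements approximately in proportion to it, assigning
    # nothing to client-mode peers (throughput 0.0). Illustrative numbers, not an exact contract:
    # >>> load_balance_peers(1000, [100.0, 100.0, 0.0], 0)  # -> roughly (500, 500, 0)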

    def update_tensors(self, allreduce_group: AllReduceRunner):
        """
        a private (extendable) method that applies changes from a finished allreduce to local tensors
        """
        assert allreduce_group.return_deltas and allreduce_group.future.done()
        averaging_deltas = allreduce_group.future.result()

        with torch.no_grad(), self.get_tensors() as local_tensors:
            assert len(local_tensors) == len(self._averaged_tensors)
            for tensor, update in zip(local_tensors, averaging_deltas):
                tensor.add_(update, alpha=self._averaging_alpha)
        self.last_updated = get_dht_time()

    @contextlib.contextmanager
    def get_tensors(self) -> Sequence[torch.Tensor]:
        """
        A contextmanager that gives user access to averaged tensors.
        It is guaranteed that the averager will not modify tensors while this context is active.
        Please do not modify the yielded tensors in-place after the context is released.
        """
        with self.lock_averaged_tensors:
            yield self._averaged_tensors
        self.last_updated = get_dht_time()

    @contextlib.asynccontextmanager
    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
        """ Like get_tensors, but uses an asynchronous contextmanager """
        try:
            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
            yield self._averaged_tensors
        finally:
            self.lock_averaged_tensors.release()

    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
        """ accept or reject a join request from another averager; if accepted, run it through allreduce steps """
        async for response in self._matchmaking.rpc_join_group(request, context):
            yield response

    async def rpc_aggregate_part(self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
                                 ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """ a groupmate sends us a part of his tensor; we should average it with other peers and return the result """
        request = await anext(stream)
        if request.group_id not in self._running_groups:
            # this handles a special case when the leader accepted us into the group AND began allreduce right away,
            # but its response with group_id was delayed and other peers got to us first
            await self._pending_group_assembled.wait()

        group = self._running_groups.get(request.group_id)
        if group is None:
            yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
            return

        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
            yield message

    async def _declare_for_download_periodically(self):
        download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
        while True:
            asyncio.create_task(asyncio.wait_for(self.dht.store(
                download_key, subkey=self.endpoint, value=self.last_updated,
                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
                timeout=self._matchmaking.averaging_expiration))
            await asyncio.sleep(self._matchmaking.averaging_expiration)

    async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
                                 ) -> AsyncIterator[averaging_pb2.DownloadData]:
        """
        Get the up-to-date trainer state from a peer.
        The state consists of two parts: (serialized_metadata, tensors)

         - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
         - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
        """
        # note: chunk_size_bytes is stored in allreduce_kwargs (see __init__), not matchmaking_kwargs
        chunk_size_bytes = self.allreduce_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
        metadata, tensors = await self._get_current_state_from_host_process()

        for tensor in tensors:
            for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
                if metadata is not None:
                    yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                    metadata = None
                else:
                    yield averaging_pb2.DownloadData(tensor_part=part)

    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
        """
        Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
        :returns: a tuple of (small metadata, sequence of torch tensors)
        :note: metadata must be serializable with self.serializer (default = MSGPackSerializer)
        """
        with self.get_tensors() as tensors:
            return dict(group_key=self.get_group_bits()), tensors

    async def _get_current_state_from_host_process(self):
        """ Executed in the averager process inside rpc_download_state """
        future, _future = MPFuture.make_pair()
        self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
        return await future

    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
        """
        Try to download the latest optimizer state from one of the existing peers.
        :returns: on success, return a 2-tuple with (metadata, tensors), where

         - metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
         - tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics

        The exact contents of both metadata and tensors are determined by the get_current_state method
        """
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
        return future.result() if wait else future
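
    # Example (sketch): a trainer can override get_current_state to share extra state (e.g. the training step)
    # and restore it via load_state_from_peers. The subclass below is purely illustrative; only the
    # (metadata, tensors) contract of get_current_state is fixed.
    # >>> class MyAverager(DecentralizedAverager):
    # >>>     def get_current_state(self):
    # >>>         with self.get_tensors() as tensors:
    # >>>             return dict(step=1234), list(tensors)
    # >>>
    # >>> loaded = my_averager.load_state_from_peers()
    # >>> if loaded is not None:
    # >>>     metadata, tensors = loaded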

    async def _load_state_from_peers(self, future: MPFuture):
        key_manager = self._matchmaking.group_key_manager
        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
        peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
                         if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}

        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
            logger.info(f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}.")
            future.set_result(None)
            return

        metadata = None
        for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
            if peer != self.endpoint:
                logger.info(f"Downloading parameters from peer {peer}")
                stream = None
                try:
                    leader_stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                    stream = leader_stub.rpc_download_state(averaging_pb2.DownloadRequest())
                    current_tensor_parts, tensors = [], []
                    async for message in stream:
                        if message.metadata:
                            metadata = self.serializer.loads(message.metadata)
                        if message.tensor_part.dtype and current_tensor_parts:
                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
                            current_tensor_parts = []
                        current_tensor_parts.append(message.tensor_part)
                    if current_tensor_parts:
                        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
                    future.set_result((metadata, tensors))
                    self.last_updated = get_dht_time()
                    return
                except grpc.aio.AioRpcError as e:
                    logger.info(f"Failed to download state from {peer} - {e}")
                finally:
                    if stream is not None:
                        await stream.code()
        else:
            logger.warning("Averager could not load state from peers: found no active peers.")
            future.set_result(None)

    def get_group_bits(self, wait: bool = True):
        """
        :param wait: if True, return bits immediately. Otherwise return awaitable MPFuture
        :returns: averager's current group key bits (without prefix)
        """
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_get_group_bits', [], dict(future=_future)))
        return future.result() if wait else future

    async def _get_group_bits(self, future: MPFuture):
        future.set_result(self._matchmaking.group_key_manager.group_bits)

    def set_group_bits(self, group_bits: str, wait: bool = True):
        """
        :param group_bits: group bits (string of '0' or '1') to be used in averager's group key
        :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
        """
        future, _future = MPFuture.make_pair()
        assert all(bit in '01' for bit in group_bits)
        self.pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
        return future.result() if wait else future

    async def _set_group_bits(self, group_bits: str, future: MPFuture):
        try:
            self._matchmaking.group_key_manager.group_bits = group_bits
            return future.set_result(None)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
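
    # Example (sketch): group bits select the matchmaking bucket; reading or overriding them by hand is only
    # needed for custom group layouts.
    # >>> current_bits = averager.get_group_bits()          # e.g. '0110' (without the shared prefix)
    # >>> averager.set_group_bits('0' * len(current_bits))  # move this peer into the all-zeros bucket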


def is_power_of_two(n):
    """ Check whether n is a power of 2 """
    return (n != 0) and (n & (n - 1) == 0)


def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.connection.Connection,
                                            get_current_state_ref: weakref.WeakMethod):
    """
    Executed in the host process as a background thread. Fetches the averager state when asked by peers.
    :param serializer: a serializer with which to convert metadata into bytes
    :param pipe: DecentralizedAverager's control pipe (from host process side)
    :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
    """
    while True:
        trigger, future = pipe.recv()
        if trigger == '_SHUTDOWN':
            break

        assert trigger == '_TRIGGER_GET_CURRENT_STATE'
        try:
            get_current_state = get_current_state_ref()
            if get_current_state is None:
                break
            state_metadata, state_tensors = get_current_state()
            del get_current_state

            state_metadata = serializer.dumps(state_metadata)
            state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
                                  for tensor in state_tensors)
            # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
            future.set_result((state_metadata, state_tensors))
        except BaseException as e:
            future.set_exception(e)
            logger.warning(e)
            continue


def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
    """ A hash that describes follower's tensor shapes, dtypes, devices, but not the actual values """
    schema_dicts = [{field_name: str(field_value)
                     for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
                    for tensor in tensors]
    return DHTID.generate(source=schema_dicts).to_bytes()