@@ -7,6 +7,7 @@ import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import random
 import threading
 import weakref
 from dataclasses import asdict
@@ -16,6 +17,7 @@ import numpy as np
 import torch

 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
+from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -28,10 +30,20 @@ from hivemind.compression import (
     serialize_torch_tensor,
 )
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
+from hivemind.utils.asyncio import (
+    achain,
+    afirst,
+    aiter_with_timeout,
+    anext,
+    as_aiter,
+    azip,
+    enter_asynchronously,
+    switch_to_uvloop,
+)
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -54,16 +66,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
-    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
-      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
+    :param min_matchmaking_time: when looking for group, wait for requests for at least this many seconds
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
-    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
-    :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
+    :note: request_timeout must be smaller than min_matchmaking_time to avoid potential deadlocks.
     :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
@@ -75,6 +85,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
+    :param declare_state_period: re-declare averager as a donor for load_state_from_peers every this many seconds
+    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+    :param next_chunk_timeout: during all-reduce and load_state_from_peers, if peer does not send next data chunk in
+      this number of seconds, consider it failed and proceed with remaining peers. default: no timeout
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 * sender_timeout
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating

     Example:
@@ -92,6 +110,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):

     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _state_updated: asyncio.Event
+    _p2p: P2P
     serializer = MSGPackSerializer

     def __init__(
@@ -101,14 +121,18 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         *,
         start: bool,
         prefix: str,
-        target_group_size: int,
+        target_group_size: Optional[int] = None,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
-        request_timeout: float = 3,
+        averaging_expiration: Optional[float] = None,
+        min_matchmaking_time: float = 5.0,
+        request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
@@ -116,6 +140,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
+        declare_state_period: float = 30,
         client_mode: Optional[bool] = None,
         daemon: bool = True,
         shutdown_timeout: float = 5,
@@ -124,17 +149,25 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert bandwidth is None or (
             bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
         ), "bandwidth must be a non-negative float32"
-        if not is_power_of_two(target_group_size):
-            logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"

+        if averaging_expiration is not None:
+            logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
+            assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
+            min_matchmaking_time = averaging_expiration
+
         super().__init__()
         self.dht = dht
         self.prefix = prefix

         if client_mode is None:
             client_mode = dht.client_mode
+        if sender_timeout is None:
+            sender_timeout = next_chunk_timeout
+        if reducer_timeout is None:
+            reducer_timeout = 2 * sender_timeout if sender_timeout is not None else None
+
         self.client_mode = client_mode

         self._parent_pid = os.getpid()
@@ -148,13 +181,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):

         self._averaged_tensors = tuple(averaged_tensors)
         self.lock_averaged_tensors = mp.Lock()
-        self.last_updated: DHTExpiration = -float("inf")
         for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
         self.shutdown_timeout = shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
         self.bandwidth = bandwidth

         self.matchmaking_kwargs = dict(
@@ -163,13 +196,15 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
             min_group_size=min_group_size,
-            averaging_expiration=averaging_expiration,
             request_timeout=request_timeout,
+            min_matchmaking_time=min_matchmaking_time,
         )
         self.allreduce_kwargs = dict(
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            sender_timeout=sender_timeout,
+            reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -177,9 +212,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon

         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self._state_sharing_priority = mp.Value(ctypes.c_double, 0)
+
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.declare_state_period = declare_state_period
         self.state_compression = state_compression
         self.tensor_infos = tensor_infos

@@ -202,9 +240,29 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     @allow_state_sharing.setter
     def allow_state_sharing(self, value: bool):
         if value and self.client_mode:
-            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
+            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state")
         else:
-            self._allow_state_sharing.value = value
+            old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
+            if value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    @property
+    def state_sharing_priority(self) -> float:
+        """Others will preferentially download state from peers with the highest priority."""
+        return float(self._state_sharing_priority.value)
+
+    @state_sharing_priority.setter
+    def state_sharing_priority(self, value: float):
+        if value and self.client_mode:
+            raise ValueError("State sharing priority is unused: averager in client mode cannot share its state")
+        else:
+            old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
+            if self.allow_state_sharing and value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    async def _trigger_declare_load_state(self):
+        # note: previously tried to set mp.Event instead of this. Awaiting it in executor caused degradation in py39
+        self._state_updated.set()

     @property
     def peer_id(self) -> PeerID:
@@ -238,7 +296,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     if not self.client_mode:
                         await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                     else:
-                        logger.debug(f"The averager is running in client mode.")
+                        logger.debug("The averager is running in client mode")

                     self._matchmaking = Matchmaking(
                         self._p2p,
@@ -250,6 +308,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     if not self.client_mode:
                         asyncio.create_task(self._declare_for_download_periodically())

+                    self._state_updated = asyncio.Event()
                     self._pending_group_assembled = asyncio.Event()
                     self._pending_group_assembled.set()
                 except Exception as e:
@@ -294,20 +353,20 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def shutdown(self) -> None:
         """Shut down the averager process"""
         if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [None], {}))  # shut down the daemon process
+            self._outer_pipe.send(("_shutdown", [self.shutdown_timeout], {}))  # shut down the daemon process
             self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
-                logger.warning("Averager did not shut down within the grace period; terminating it the hard way.")
+                logger.warning("Averager did not shut down within the grace period; terminating it the hard way")
                 self.terminate()
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")

-    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+    async def _shutdown(self, timeout: Optional[float]) -> None:
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        await asyncio.gather(*remaining_tasks)
+        await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout)

     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():
@@ -316,67 +375,96 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def step(
         self,
         gather: Optional[GatheredData] = None,
+        scheduled_time: Optional[DHTExpiration] = None,
         weight: Optional[float] = None,
         timeout: Optional[float] = None,
         allow_retries: bool = True,
+        require_trigger: bool = False,
         wait: bool = True,
-    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure

         :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
           (this operation is known as all-gather). The gathered data will be available as the output of this function.
+        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment.
+          By default, schedule all-reduce at current time plus min_matchmaking_time seconds
         :param weight: averaging weight for this peer, int or float, must be strictly positive
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
-        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failedK
-        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :param require_trigger: if True, wait for the user to call .allow_allreduce() before running all-reduce
+        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
+        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
-            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+            logger.warning("Averager is running in auxiliary mode, weight is unused")
+        if scheduled_time is None:
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
+        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
-
-        future = MPFuture()
-        gather_binary = self.serializer.dumps(
-            gather
-        )  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(
-            (
-                "_step",
-                [],
-                dict(
-                    future=future,
-                    gather_binary=gather_binary,
-                    weight=weight,
-                    allow_retries=allow_retries,
-                    timeout=timeout,
-                ),
-            )
+        assert not (wait and require_trigger), "Non-asynchronous step cannot wait for trigger (use wait=False)"
+        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"
+
+        user_data_for_gather = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, user_data_for_gather])
+        step = StepControl(
+            scheduled_time=scheduled_time,
+            deadline=deadline,
+            allow_retries=allow_retries,
+            weight=weight,
+            data_for_gather=data_for_gather,
         )
-        return future.result() if wait else future

-    async def _step(
-        self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
-    ):
-        start_time = get_dht_time()
+        future_for_init = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_init=future_for_init)))
+        step.attach(*future_for_init.result())

+        if not require_trigger:
+            step.allow_allreduce()
+        return step.result() if wait else step
+
+    async def _step(self, *, step: StepControl, future_for_init: MPFuture):
         try:
-            while not future.done():
+            trigger, cancel = MPFuture(), MPFuture()
+            step.attach(trigger, cancel)
+            future_for_init.set_result((trigger, cancel))
+
+            async def find_peers_or_notify_cancel():
+                group_info = await self._matchmaking.look_for_group(step)
+                if not step.triggered:
+                    step.stage = AveragingStage.AWAITING_TRIGGER
+                    await step.wait_for_trigger()
+                return group_info
+
+            while not step.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(
-                        timeout=timeout, data_for_gather=data_for_gather
-                    )
+                    step.stage = AveragingStage.LOOKING_FOR_GROUP
+                    matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
+                    check_cancel_task = asyncio.create_task(step.wait_for_cancel())
+
+                    await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
+                    if step.cancelled():
+                        matchmaking_task.cancel()
+                        raise asyncio.CancelledError()
+                    else:
+                        check_cancel_task.cancel()
+
+                    group_info = await matchmaking_task
+
                     if group_info is None:
-                        raise AllreduceException("Averaging step failed: could not find a group.")
+                        raise AllreduceException("Averaging step failed: could not find a group")

-                    future.set_result(
+                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                    step.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            self._run_allreduce(
+                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
+                            ),
                             timeout=self._allreduce_timeout,
                         )
                     )
@@ -390,21 +478,25 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     P2PHandlerError,
+                    DispatchFailure,
+                    ControlFailure,
                 ) as e:
-                    time_elapsed = get_dht_time() - start_time
-                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                        logger.exception(f"Averager caught {repr(e)}")
-                        future.set_exception(e)
+                    if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
+                        if not step.cancelled():
+                            logger.exception(e)
+                        if not step.done():
+                            step.set_exception(e)
                     else:
-                        logger.warning(f"Averager caught {repr(e)}, retrying")
+                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")

         except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
+            if not step.done():
+                step.set_exception(e)
             raise
         finally:
-            if not future.done():
-                future.set_exception(
+            step.stage = AveragingStage.FINISHED
+            if not step.done():
+                step.set_exception(
                     RuntimeError(
                         "Internal sanity check failed: averager.step left future pending."
                         " Please report this to hivemind issues."
@@ -414,8 +506,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))

             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
@@ -426,7 +518,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )

-            async with self.get_tensors_async() as local_tensors:
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -435,26 +527,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                    tensors=local_tensors,
                    ordered_peer_ids=group_info.peer_ids,
                    peer_fractions=peer_fractions,
-                    weights=weights,
                    gathered=user_gathered,
                    modes=modes,
                    **kwargs,
                )

                with self.register_allreduce_group(group_info.group_id, allreduce):
-
-                    # actually run all-reduce
-                    averaging_outputs = [output async for output in allreduce]
-
                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        assert len(local_tensors) == len(self._averaged_tensors)
-                        for tensor, update in zip(local_tensors, averaging_outputs):
+                        iter_results = allreduce.run()
+                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                            # all-reduce is performed asynchronously while iterating
                            tensor.add_(update, alpha=self._averaging_alpha)
-                    self.last_updated = get_dht_time()
+                        self._state_updated.set()
+
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")

                return allreduce.gathered
        except BaseException as e:
-            logger.exception(e)
+            if isinstance(e, Exception):
+                logger.exception(e)
            raise MatchmakingException(f"Unable to run All-Reduce: {e}")

    @contextlib.contextmanager
@@ -477,16 +570,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
-        self.last_updated = get_dht_time()
-
-    @contextlib.asynccontextmanager
-    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
-        """Like get_tensors, but uses an asynchronous contextmanager"""
-        try:
-            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
-            yield self._averaged_tensors
-        finally:
-            self.lock_averaged_tensors.release()

     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
@@ -515,21 +598,31 @@ class DecentralizedAverager(mp.Process, ServicerBase):

     async def _declare_for_download_periodically(self):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
+        sharing_was_allowed = self.allow_state_sharing
         while True:
-            if self.allow_state_sharing:
+            expiration_time = get_dht_time() + self.declare_state_period
+            if self.allow_state_sharing or sharing_was_allowed:
+                # notify either if sharing is allowed or if it was just switched off (to overwrite previous message)
                 asyncio.create_task(
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
                             subkey=self.peer_id.to_bytes(),
-                            value=self.last_updated,
-                            expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
+                            value=self.state_sharing_priority if self.allow_state_sharing else None,
+                            expiration_time=expiration_time,
                             return_future=True,
                         ),
-                        timeout=self._matchmaking.averaging_expiration,
+                        timeout=expiration_time - get_dht_time(),
                     )
                 )
-            await asyncio.sleep(self._matchmaking.averaging_expiration)
+                sharing_was_allowed = self.allow_state_sharing
+
+            # report again either in state_declare_period or after the field was changed by the user
+            self._state_updated.clear()
+            try:
+                await asyncio.wait_for(self._state_updated.wait(), timeout=max(0.0, expiration_time - get_dht_time()))
+            except asyncio.TimeoutError:
+                pass

     async def rpc_download_state(
         self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
@@ -584,21 +677,23 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future = MPFuture()
-        self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
+        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
         return future.result(timeout=timeout) if wait else future

-    async def _load_state_from_peers(self, future: MPFuture):
+    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
+        if timeout is not None:
+            timeout = self.next_chunk_timeout if self.next_chunk_timeout is not None else self.request_timeout
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
-                PeerID(peer_id): float(info.value)
+                PeerID(peer_id): (float(info.value), random.random())  # using randomness as a tie breaker
                 for peer_id, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }

             if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
-                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
+                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}")
                 future.set_result(None)
                 return

@@ -608,10 +703,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                    logger.info(f"Downloading parameters from peer {peer}")
                    try:
                        stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
-                        stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                        stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                        current_tensor_parts, tensors = [], []

-                        async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
+                        async for message in aiter_with_timeout(stream, timeout=timeout):
                            if message.metadata:
                                metadata = self.serializer.loads(message.metadata)
                            if message.tensor_part.dtype and current_tensor_parts:
@@ -623,12 +718,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))

                        if not metadata:
-                            logger.debug(f"Peer {peer} did not send its state.")
+                            logger.debug(f"Peer {peer} did not send its state")
                            continue

                        logger.info(f"Finished downloading state from {peer}")
                        future.set_result((metadata, tensors))
-                        self.last_updated = get_dht_time()
                        return
                    except Exception as e:
                        logger.exception(f"Failed to download state from {peer} - {repr(e)}")
@@ -668,11 +762,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             future.set_exception(e)


-def is_power_of_two(n):
-    """Check whether n is a power of 2"""
-    return (n != 0) and (n & (n - 1) == 0)
-
-
 def _background_thread_fetch_current_state(
     serializer: SerializerBase, pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod
 ):