@@ -20,7 +20,8 @@ import torch
 import numpy as np
 
 from hivemind.dht import DHT, DHTID
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
+from hivemind.client.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.client.averaging.group_info import GroupInfo
@@ -34,9 +35,8 @@ from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescripto
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
-DataForGather = Any
+GatheredData = Any
 logger = get_logger(__name__)
-DEFAULT_CHUNK_SIZE_BYTES = 2 ** 16
 
 
 class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
@@ -61,7 +61,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
-    :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
+    :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param throughput: if specified, this value represents the network bandwidth available to averager.
       By default, the averager is assumed to have the average bandwidth of his group.
       If throughput == 0, averager will rely on its groupmates to do all the averaging.
@@ -71,6 +71,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
       see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
+    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
+      local tensors for averaging
+    :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
+      with averager.allow_state_sharing = True / False
 
     Example:
 
@@ -90,10 +94,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
-                 allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
+                 averaging_expiration: float = 15, request_timeout: float = 3, averaging_alpha: float = 1.0,
+                 part_size_bytes: int = DEFAULT_PART_SIZE_BYTES, allreduce_timeout: Optional[float] = None,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0,
+                 auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
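For orientation, a minimal sketch (not part of this diff) of how the updated constructor might be called; the prefix, tensor shape, and group size are illustrative, and it assumes hivemind's usual top-level re-exports and DHT bootstrap:

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)
averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(16)], dht=dht, start=True,
    prefix='demo_run', target_group_size=4,
    part_size_bytes=2 ** 19,          # replaces the removed chunk_size_bytes argument
    auxiliary=False,                  # True would make step() only assist other peers
    allow_state_sharing=True,         # peers may download this peer's state
)
```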
@@ -102,10 +107,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
         self.dht = dht
         self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
+        if not self.listen:
+            self.mode = AveragingMode.CLIENT
+        elif auxiliary:
+            self.mode = AveragingMode.AUX
+        else:
+            self.mode = AveragingMode.NODE
+
         self.channel_options = channel_options
         self.daemon = daemon
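The branch above maps the two constructor flags onto three averaging modes. A standalone sketch of that mapping; the enum here is a stand-in for the imported AveragingMode, whose exact member values are assumed:

```python
from enum import Enum

class AveragingMode(Enum):      # stand-in for hivemind.client.averaging.allreduce.AveragingMode
    NODE = 0                    # regular peer: averages local tensors and serves requests
    CLIENT = 1                  # listen=False: participates but accepts no incoming connections
    AUX = 2                     # auxiliary: assists the group without contributing local tensors

def select_mode(listen: bool, auxiliary: bool) -> AveragingMode:
    assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
    if not listen:
        return AveragingMode.CLIENT
    return AveragingMode.AUX if auxiliary else AveragingMode.NODE

assert select_mode(listen=True, auxiliary=False) is AveragingMode.NODE
assert select_mode(listen=False, auxiliary=False) is AveragingMode.CLIENT
assert select_mode(listen=True, auxiliary=True) is AveragingMode.AUX
```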
@@ -122,13 +135,17 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
             min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
-        self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes,
+        self.allreduce_kwargs = dict(compression_type=compression_type, part_size_bytes=part_size_bytes,
                                      min_vector_size=min_vector_size)
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
+
+        self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing
+
         self._averager_endpoint: Optional[Endpoint] = None
         if not self.listen:
             self._averager_endpoint = f'client::{uuid.uuid4()}'
@@ -146,6 +163,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     def port(self) -> Optional[Port]:
         return self._port.value if self._port.value != 0 else None
 
+    @property
+    def allow_state_sharing(self) -> bool:
+        """ if set to True, other peers can download this peer's state """
+        return bool(self._allow_state_sharing.value)
+
+    @allow_state_sharing.setter
+    def allow_state_sharing(self, value: bool):
+        if value is True and not self.listen:
+            logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.")
+        else:
+            self._allow_state_sharing.value = value
+
     @property
     def endpoint(self) -> Optional[Endpoint]:
         if self.listen and self._averager_endpoint is None:
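The new property is backed by a multiprocessing.Value so that toggling it from the host process is visible inside the averager's background process. A self-contained sketch of that pattern; the class and message below are illustrative, not part of the diff:

```python
import ctypes
import multiprocessing as mp

class SharingFlag:
    """ Illustrative stand-in: a process-shared boolean exposed as a property. """
    def __init__(self, listen: bool = True):
        self.listen = listen
        self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)

    @property
    def allow_state_sharing(self) -> bool:
        return bool(self._allow_state_sharing.value)

    @allow_state_sharing.setter
    def allow_state_sharing(self, value: bool):
        if value and not self.listen:
            print("client-mode peer cannot share its state")  # the averager logs a warning here
        else:
            self._allow_state_sharing.value = value

flag = SharingFlag(listen=True)
flag.allow_state_sharing = True      # averager.allow_state_sharing = True / False works the same way
assert flag.allow_state_sharing
```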
@@ -222,8 +251,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if self._parent_pid != os.getpid() or self.is_alive():
             self.shutdown()
 
-    def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None,
-             allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
+    def step(self, gather: Optional[GatheredData] = None, weight: Optional[float] = None,
+             timeout: Optional[float] = None, allow_retries: bool = True, wait: bool = True
+             ) -> Union[Optional[Dict[Endpoint, GatheredData]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
@@ -236,7 +266,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
-        assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+        if self.mode == AveragingMode.AUX and weight is not None:
+            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        if weight is None:
+            weight = float(self.mode != AveragingMode.AUX)
+        assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
+
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
         self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
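The new weight handling defaults to 1.0 for regular peers and 0.0 for auxiliary peers, which assist the group without averaging their own tensors. A small standalone sketch of that rule, with the log call replaced by a print for brevity:

```python
def resolve_weight(weight, is_aux: bool) -> float:
    if is_aux and weight is not None:
        print("averager is running in auxiliary mode, weight is unused")
    if weight is None:
        weight = float(not is_aux)   # 1.0 for NODE/CLIENT peers, 0.0 for AUX peers
    assert isinstance(weight, (int, float)) and weight >= 0
    return weight

assert resolve_weight(None, is_aux=False) == 1.0   # default contribution for a regular peer
assert resolve_weight(None, is_aux=True) == 0.0    # auxiliary peers contribute no local tensors
assert resolve_weight(0.5, is_aux=False) == 0.5    # explicit weights are passed through unchanged
```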
@@ -245,28 +280,21 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
                     allow_retries: bool, timeout: Optional[float]):
-        loop = asyncio.get_event_loop()
         start_time = get_dht_time()
-        group_id = None
-
         try:
             while not future.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary])
                     group_info = await self._matchmaking.look_for_group(timeout=timeout,
                                                                         data_for_gather=data_for_gather)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
-                    group_id = group_info.group_id
-                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                    self._running_groups[group_id] = allreduce_runner
-                    self._pending_group_assembled.set()
-                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
-
-                    # averaging is finished, exit the loop
-                    future.set_result(allreduce_runner.gathered)
+                    future.set_result(await asyncio.wait_for(
+                        self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout))
+                    # averaging is finished, loop will now exit
 
                 except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
                         asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
@@ -277,10 +305,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                     else:
                         logger.warning(f"Averager caught {repr(e)}, retrying")
 
-                finally:
-                    _ = self._running_groups.pop(group_id, None)
-                    self._pending_group_assembled.set()
-
         except BaseException as e:
             if not future.done():
                 future.set_exception(e)
@@ -290,35 +314,51 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
                                               " Please report this to hivemind issues."))
 
-    async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
-        """ Use a group description found by Matchmaking to form AllreduceRunner """
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """ Run All-Reduce in a given group and update tensors in place, return gathered metadata """
         try:
-            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
+            modes = tuple(map(AveragingMode, mode_ids))
 
-            # compute optimal part sizes from peer throughputs
-            incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
-            part_sizes = await asyncio.get_event_loop().run_in_executor(
+            # compute optimal part sizes from peer throughputs; TODO: replace with proper load balancing
+            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0
+                                    for thr, mode in zip(throughputs, modes)]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                 None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
-            async with self.get_tensors_async() as averaged_tensors:
-                return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
-                                       ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
-                                       weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
-        except Exception as e:
-            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
-
-    def update_tensors(self, allreduce_group: AllReduceRunner):
-        """
-        a private (extendable) method that applies changes from a finished allreduce to local tensors
-        """
-        assert allreduce_group.return_deltas and allreduce_group.future.done()
-        averaging_deltas = allreduce_group.future.result()
+            async with self.get_tensors_async() as local_tensors:
+                allreduce = AllReduceRunner(
+                    group_id=group_info.group_id, tensors=local_tensors, endpoint=self.endpoint,
+                    ordered_group_endpoints=group_info.endpoints, peer_fractions=peer_fractions, weights=weights,
+                    gathered=user_gathered, modes=modes, **kwargs)
 
-        with torch.no_grad(), self.get_tensors() as local_tensors:
-            assert len(local_tensors) == len(self._averaged_tensors)
-            for tensor, update in zip(local_tensors, averaging_deltas):
-                tensor.add_(update, alpha=self._averaging_alpha)
-            self.last_updated = get_dht_time()
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+
+                    # actually run all-reduce
+                    averaging_outputs = [output async for output in allreduce]
+
+                    if modes[group_info.endpoints.index(self.endpoint)] != AveragingMode.AUX:
+                        assert len(local_tensors) == len(self._averaged_tensors)
+                        for tensor, update in zip(local_tensors, averaging_outputs):
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self.last_updated = get_dht_time()
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    @contextlib.contextmanager
+    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
+        """ registers a given all-reduce runner to listen for incoming connections """
+        try:
+            self._running_groups[group_id] = allreduce
+            self._pending_group_assembled.set()
+            yield
+        finally:
+            self._running_groups.pop(group_id, None)
+            self._pending_group_assembled.set()
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
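register_allreduce_group takes over the cleanup that used to live in the finally block removed from _step: registration and unregistration now sit next to each other, so the group is removed even if the all-reduce raises. A self-contained sketch of the same pattern:

```python
import contextlib

_running_groups = {}

@contextlib.contextmanager
def register_allreduce_group(group_id, runner):
    """ register a runner so incoming all-reduce requests can be routed to it, then always clean up """
    try:
        _running_groups[group_id] = runner
        yield
    finally:
        _running_groups.pop(group_id, None)

with register_allreduce_group(b'group-1', object()):
    assert b'group-1' in _running_groups     # visible while the all-reduce is running
assert b'group-1' not in _running_groups     # removed afterwards, even on failure
```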
@@ -366,10 +406,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _declare_for_download_periodically(self):
         download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
         while True:
-            asyncio.create_task(asyncio.wait_for(self.dht.store(
-                download_key, subkey=self.endpoint, value=self.last_updated,
-                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
-                timeout=self._matchmaking.averaging_expiration))
+            if self.allow_state_sharing:
+                asyncio.create_task(asyncio.wait_for(self.dht.store(
+                    download_key, subkey=self.endpoint, value=self.last_updated,
+                    expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
+                    timeout=self._matchmaking.averaging_expiration))
             await asyncio.sleep(self._matchmaking.averaging_expiration)
 
     async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
@@ -381,11 +422,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
             - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
-        chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
+        if not self.allow_state_sharing:
+            return  # deny request and direct peer to the next prospective averager
         metadata, tensors = await self._get_current_state_from_host_process()
 
         for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
+            for part in split_for_streaming(serialize_torch_tensor(tensor)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
@@ -452,6 +494,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                         current_tensor_parts.append(message.tensor_part)
                 if current_tensor_parts:
                     tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+
+                if not metadata:
+                    logger.debug(f"Peer {peer} did not send its state.")
+                    continue
+
                 logger.info(f"Finished downloading state from {peer}")
                 future.set_result((metadata, tensors))
                 self.last_updated = get_dht_time()
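Together, the two hunks above form a simple serve/skip handshake: a peer that disallows state sharing ends its response stream without yielding anything, and the downloader treats an empty reply (no metadata) as "try the next peer". A hedged toy sketch of that flow; the helper functions and peer names are illustrative only:

```python
def serve_state(allow_state_sharing: bool, metadata: bytes, parts):
    """ Toy server: yields (metadata, part) messages, or nothing when sharing is disabled. """
    if not allow_state_sharing:
        return                      # empty stream: deny the request
    for part in parts:
        yield metadata, part
        metadata = b''              # metadata is only attached to the first message

def download_from(peers):
    """ Toy client: skip peers that sent nothing, return the first usable state. """
    for peer, stream in peers:
        received = list(stream)
        if not received:            # peer did not send its state, move on
            continue
        return peer, received
    return None

peers = [('silent-peer', serve_state(False, b'meta', ['p1'])),
         ('sharing-peer', serve_state(True, b'meta', ['p1', 'p2']))]
assert download_from(peers)[0] == 'sharing-peer'
```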
|
@@ -512,7 +559,12 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
|
|
:param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
|
|
:param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
|
|
"""
|
|
"""
|
|
while True:
|
|
while True:
|
|
- trigger, future = pipe.recv()
|
|
|
|
|
|
+ try:
|
|
|
|
+ trigger, future = pipe.recv()
|
|
|
|
+ except BaseException as e:
|
|
|
|
+ logger.debug(f"Averager background thread finished: {repr(e)}")
|
|
|
|
+ break
|
|
|
|
+
|
|
if trigger == '_SHUTDOWN':
|
|
if trigger == '_SHUTDOWN':
|
|
break
|
|
break
|
|
|
|
|
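Wrapping pipe.recv() in try/except lets the state-fetching thread exit quietly once the control pipe is closed (for example, when the averager shuts down) instead of dying with an unhandled exception. A runnable sketch of that behaviour under the same assumptions:

```python
import multiprocessing as mp
import threading

def fetcher(pipe):
    while True:
        try:
            trigger, payload = pipe.recv()
        except BaseException as e:            # EOFError/OSError once the other end is closed
            print(f"background thread finished: {repr(e)}")
            break
        if trigger == '_SHUTDOWN':
            break

parent_end, child_end = mp.Pipe(duplex=True)
thread = threading.Thread(target=fetcher, args=(child_end,), daemon=True)
thread.start()
parent_end.close()                            # closing the pipe alone is enough to stop the thread
thread.join()
```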