@@ -20,7 +20,7 @@ import torch
 import numpy as np
 
 from hivemind.dht import DHT, DHTID
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
 from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.client.averaging.group_info import GroupInfo
@@ -71,6 +71,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
       see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
     :param kwargs: extra parameters forwarded to grpc.aio.server
+    :param auxiliary: if this flag is specified, averager.step will only assist others without sending
+      local tensors for averaging
+    :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
+      with averager.allow_state_sharing = True / False
 
     Example:
 
@@ -94,6 +98,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                  allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  throughput: Optional[float] = None, min_vector_size: int = 0,
+                 auxiliary: bool = False, allow_state_sharing: Optional[bool] = None,
                  listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
         assert '.' not in prefix, "group prefix must be a string without trailing '.'"
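
For context, a hedged usage sketch of the two new constructor arguments; the DHT setup, prefix, and tensor shapes below are illustrative and not taken from this patch:

import torch
import hivemind

dht = hivemind.DHT(start=True)

# a regular peer: contributes its tensors and, by default, also shares its state for download
trainer = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.randn(16)], dht=dht, start=True,
    prefix="demo_run", target_group_size=4)

# an auxiliary peer: assists other peers during all-reduce but never sends local tensors;
# it must keep listen=True (enforced by the assert added further below)
aux = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(16)],  # placeholder, never transmitted
    dht=dht, start=True, prefix="demo_run", target_group_size=4,
    auxiliary=True, allow_state_sharing=False)

trainer.allow_state_sharing = False  # the flag can also be flipped at runtime
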
@@ -102,10 +107,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in '01' for bit in initial_group_bits)
+        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
 
         super().__init__()
         self.dht = dht
         self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
+        if not self.listen:
+            self.mode = AveragingMode.CLIENT
+        elif auxiliary:
+            self.mode = AveragingMode.AUX
+        else:
+            self.mode = AveragingMode.NODE
+
         self.channel_options = channel_options
         self.daemon = daemon
 
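
The branch above relies on the AveragingMode enum imported in the first hunk; its member values are not part of this patch, so the following standalone sketch of the mapping is an assumption (NODE / CLIENT / AUX are the only names this diff uses):

from enum import Enum

class AveragingMode(Enum):
    NODE = 0    # regular peer: listens for connections and contributes its tensors
    CLIENT = 1  # listen=False: contributes tensors but cannot accept incoming connections
    AUX = 2     # auxiliary: accepts connections and assists, but sends no local tensors

def select_mode(listen: bool, auxiliary: bool) -> AveragingMode:
    # mirrors the constructor logic; listen=False together with auxiliary=True is rejected earlier by the assert
    if not listen:
        return AveragingMode.CLIENT
    if auxiliary:
        return AveragingMode.AUX
    return AveragingMode.NODE

assert select_mode(listen=True, auxiliary=True) is AveragingMode.AUX
assert select_mode(listen=False, auxiliary=False) is AveragingMode.CLIENT
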
@@ -129,6 +142,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port
+
+        self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing
+
         self._averager_endpoint: Optional[Endpoint] = None
         if not self.listen:
             self._averager_endpoint = f'client::{uuid.uuid4()}'
@@ -146,6 +163,18 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     def port(self) -> Optional[Port]:
         return self._port.value if self._port.value != 0 else None
 
+    @property
+    def allow_state_sharing(self) -> bool:
+        """ if set to True, other peers can download this peer's state """
+        return bool(self._allow_state_sharing.value)
+
+    @allow_state_sharing.setter
+    def allow_state_sharing(self, value: bool):
+        if value and not self.listen:
+            logger.warning("Cannot allow state sharing: averager in client mode (listen=False) cannot share its state.")
+        else:
+            self._allow_state_sharing.value = value
+
     @property
     def endpoint(self) -> Optional[Endpoint]:
         if self.listen and self._averager_endpoint is None:
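
The property pair above is backed by an mp.Value rather than a plain attribute because the averager itself runs in a background process (it is an mp.Process), so a regular Python attribute set in the user process would not be visible there. A standalone sketch of that mechanism, not hivemind code:

import ctypes
import multiprocessing as mp
import time

def serve_until_allowed(shared_flag):
    # the background process observes changes made by the parent through shared memory
    while not shared_flag.value:
        time.sleep(0.01)
    print("state sharing enabled, serving download requests")

if __name__ == "__main__":
    flag = mp.Value(ctypes.c_bool, False)
    worker = mp.Process(target=serve_until_allowed, args=(flag,))
    worker.start()
    flag.value = True  # analogous to `averager.allow_state_sharing = True` in the host process
    worker.join()
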
@@ -236,7 +265,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
-        assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+        if self.mode == AveragingMode.AUX and weight != 1:
+            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        else:
+            assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
+
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
         self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
@@ -253,7 +286,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         while not future.done():
             try:
                 self._pending_group_assembled.clear()
-                data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                data_for_gather = self.serializer.dumps([weight, self._throughput, self.mode.value, gather_binary])
                 group_info = await self._matchmaking.look_for_group(timeout=timeout,
                                                                     data_for_gather=data_for_gather)
                 if group_info is None:
@@ -263,7 +296,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 self._running_groups[group_id] = allreduce_runner
                 self._pending_group_assembled.set()
                 await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
+                if self.mode != AveragingMode.AUX:
+                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
 
                 # averaging is finished, exit the loop
                 future.set_result(allreduce_runner.gathered)
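
The list serialized into data_for_gather above ([weight, throughput, mode.value, gather_binary]) is what each peer attaches to its join request; the next hunk decodes it and zeroes the incoming throughput of client-mode peers so they are not assigned a slice of the averaged vector. A minimal round-trip sketch, with pickle standing in for the averager's serializer and the same assumed AveragingMode values as above:

import pickle
from enum import Enum

class AveragingMode(Enum):  # assumed values, see the import in the first hunk
    NODE = 0
    CLIENT = 1
    AUX = 2

# what three peers might send alongside their join requests: [weight, throughput, mode, gather_binary]
sent = [
    pickle.dumps([1.0, 100.0, AveragingMode.NODE.value, b""]),
    pickle.dumps([1.0, 50.0, AveragingMode.CLIENT.value, b""]),
    pickle.dumps([1.0, 80.0, AveragingMode.AUX.value, b""]),
]

# decoding mirrors _make_allreduce_runner below
weights, throughputs, mode_ids, _gather = zip(*map(pickle.loads, sent))
modes = tuple(map(AveragingMode, mode_ids))
incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0
                        for thr, mode in zip(throughputs, modes)]
assert incoming_throughputs == [100.0, 0.0, 80.0]  # the client peer receives no part of the vector
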
@@ -293,19 +327,19 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         try:
-            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
-
             # compute optimal part sizes from peer throughputs
-            incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
+            modes = tuple(map(AveragingMode, mode_ids))
+            incoming_throughputs = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)]  # TODO: replace with proper load balancing
             part_sizes = await asyncio.get_event_loop().run_in_executor(
                 None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
             async with self.get_tensors_async() as averaged_tensors:
                 return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
                                        ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
-                                       weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
+                                       weights=weights, gathered=user_gathered, return_deltas=True, modes=modes, **kwargs)
         except Exception as e:
-            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
+            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {weights, throughputs, modes, user_gathered}")
 
     def update_tensors(self, allreduce_group: AllReduceRunner):
         """
@@ -366,10 +400,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     async def _declare_for_download_periodically(self):
         download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
         while True:
-            asyncio.create_task(asyncio.wait_for(self.dht.store(
-                download_key, subkey=self.endpoint, value=self.last_updated,
-                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
-                timeout=self._matchmaking.averaging_expiration))
+            if self.allow_state_sharing:
+                asyncio.create_task(asyncio.wait_for(self.dht.store(
+                    download_key, subkey=self.endpoint, value=self.last_updated,
+                    expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
+                    timeout=self._matchmaking.averaging_expiration))
             await asyncio.sleep(self._matchmaking.averaging_expiration)
 
     async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
@@ -381,6 +416,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
          - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
          - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
+        if not self.allow_state_sharing:
+            return  # deny request and direct peer to the next prospective averager
         chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
         metadata, tensors = await self._get_current_state_from_host_process()
 
@@ -452,6 +489,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                         current_tensor_parts.append(message.tensor_part)
                     if current_tensor_parts:
                         tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+
+                    if not metadata:
+                        logger.debug(f"Peer {peer} did not send its state.")
+                        continue
+
                     logger.info(f"Finished downloading state from {peer}")
                     future.set_result((metadata, tensors))
                     self.last_updated = get_dht_time()
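
Taken together, the last three hunks mean a peer with allow_state_sharing=False neither advertises itself under the '.all_averagers' key nor answers rpc_download_state, and a downloader that still reaches it simply moves on to the next candidate. A self-contained sketch of that fallthrough, with hypothetical peer names and a stand-in fetch callable:

from typing import Callable, List, Optional, Tuple

def download_state(peers: List[str],
                   fetch: Callable[[str], Tuple[bytes, list]]) -> Optional[Tuple[bytes, list]]:
    """Try peers in order; skip any peer that declined to share (empty metadata)."""
    for peer in peers:
        metadata, tensors = fetch(peer)  # stands in for the rpc_download_state stream
        if not metadata:
            continue  # this peer had allow_state_sharing=False
        return metadata, tensors
    return None

responses = {"peer_a": (b"", []), "peer_b": (b"serialized_metadata", ["tensor_0"])}
assert download_state(["peer_a", "peer_b"], lambda p: responses[p]) == (b"serialized_metadata", ["tensor_0"])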
|