@@ -32,7 +32,7 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
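The newly imported `azip` is hivemind's async analogue of the builtin `zip`; together with `as_aiter`, which turns plain values into an async iterator, it lets the averager consume all-reduce outputs as a paired stream (see the last hunk below). A minimal sketch of the semantics this patch relies on, with illustrative values:

```python
import asyncio

from hivemind.utils.asyncio import as_aiter, azip

async def demo():
    # as_aiter wraps plain values into an async iterator;
    # azip pairs the two async streams element by element, like builtin zip
    async for number, letter in azip(as_aiter(1, 2, 3), as_aiter("a", "b", "c")):
        print(number, letter)  # 1 a, then 2 b, then 3 c

asyncio.run(demo())
```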
@@ -344,26 +344,33 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         scheduled_time = get_dht_time() + self.matchmaking_kwargs["averaging_expiration"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
+        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
         assert not wait_for_trigger or wait, "Non-asynchronous step cannot wait for trigger (use wait=False)"
-        gather_binary = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"

-        step = StepControl(scheduled_time, weight, wait_for_trigger=wait_for_trigger,
-                           gather_binary=gather_binary, timeout=timeout, allow_retries=allow_retries)
-        self._outer_pipe.send(("_step", [], dict(step=step)))
+        user_gather_bytes = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        gather_binary = self.serializer.dumps([self.bandwidth, self.mode.value, user_gather_bytes])
+        step = StepControl(scheduled_time=scheduled_time, deadline=deadline, allow_retries=allow_retries,
+                           weight=weight, gather_binary=gather_binary)
+
+        future_for_trigger = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
+        step.attach_trigger(future_for_trigger.result())
+
+        if not wait_for_trigger:
+            step.allow_allreduce()
         return step.result() if wait else step
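With this hunk, a caller that passes `wait_for_trigger=True` gets back a `StepControl` and decides when the all-reduce may begin. A hedged usage sketch, assuming a running `DecentralizedAverager` named `averager` (the name is illustrative):

```python
# start matchmaking in the background, but do not run all-reduce yet
control = averager.step(wait=False, wait_for_trigger=True)

# ... do other work, e.g. finish accumulating local gradients ...

control.allow_allreduce()             # resolve the trigger; averaging may now proceed
averaged_metadata = control.result()  # block until the step succeeds or raises
```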

-    async def _step(
-        self, *, step: StepControl, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]
-    ):
-        start_time = get_dht_time()
+    async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
+        trigger = MPFuture()
+        step.attach_trigger(trigger)
+        future_for_trigger.set_result(trigger)

         try:
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
-
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
                     group_info = await self._matchmaking.look_for_group(step)
                     if group_info is None:
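Both copies of the `StepControl` (one per process) end up attached to the same `MPFuture`, so `allow_allreduce()` on the caller side releases the averager side. The handshake from this hunk and the previous one, condensed into one place for readability:

```python
# averager process (_step): create the trigger and hand it back
trigger = MPFuture()
step.attach_trigger(trigger)            # local copy of step now waits on it
future_for_trigger.set_result(trigger)  # ship it to the caller process

# caller process (step): attach the same trigger to the caller's copy
step.attach_trigger(future_for_trigger.result())  # blocks until _step responds
```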
@@ -394,12 +401,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.InvalidStateError,
                     P2PHandlerError,
                 ) as e:
-                    time_elapsed = get_dht_time() - start_time
-                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                        logger.exception(f"Averager caught {repr(e)}")
+                    if not step.allow_retries or get_dht_time() >= step.deadline:
+                        logger.exception(f"{self.__class__.__name__} caught {repr(e)}")
                         step.set_exception(e)
                     else:
-                        logger.warning(f"Averager caught {repr(e)}, retrying")
+                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")

         except BaseException as e:
             if not step.done():
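The retry branch now compares against `step.deadline` directly; with the `float("inf")` fallback used when no timeout is given, the comparison is always well defined and a timeout-free step retries indefinitely. A standalone illustration (not averager code):

```python
from hivemind.utils.timed_storage import get_dht_time

timeout = None  # caller did not set a timeout
deadline = get_dht_time() + timeout if timeout is not None else float("inf")
assert get_dht_time() < deadline  # never past the deadline, so retries stay allowed
```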
@@ -418,8 +424,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))

             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
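Each peer's gathered blob is now serialized in two layers: the user's `gather` payload first, then a `[bandwidth, mode, payload]` envelope, which is what `_run_allreduce` unpacks above. A round-trip sketch with illustrative values (`MSGPackSerializer` is the serializer imported at the top of this file):

```python
from hivemind.utils.serializer import MSGPackSerializer

user_gather_bytes = MSGPackSerializer.dumps({"step": 42})        # inner layer: user metadata
envelope = MSGPackSerializer.dumps([1e9, 0, user_gather_bytes])  # outer layer: bandwidth, mode id, payload

bandwidth, mode_id, inner = MSGPackSerializer.loads(envelope)
assert MSGPackSerializer.loads(inner) == {"step": 42}
```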
@@ -445,15 +451,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             )

             with self.register_allreduce_group(group_info.group_id, allreduce):
-
-                # actually run all-reduce
-                averaging_outputs = [output async for output in allreduce]
-
-                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                    assert len(local_tensors) == len(self._averaged_tensors)
-                    for tensor, update in zip(local_tensors, averaging_outputs):
-                        tensor.add_(update, alpha=self._averaging_alpha)
-                    self.last_updated = get_dht_time()
+                assert len(local_tensors) == len(self._averaged_tensors)
+                async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                    # note: all-reduce is performed asynchronously when iterating
+                    tensor.add_(update, alpha=self._averaging_alpha)
+                self.last_updated = get_dht_time()

             return allreduce.gathered
         except BaseException as e:
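The replaced block materialized every averaged output before applying any update; the `azip` form applies each update in place as soon as the corresponding all-reduce part arrives. The in-place blend itself is plain torch arithmetic (illustrative values; `alpha` stands in for `self._averaging_alpha`):

```python
import torch

alpha = 1.0                               # stand-in for self._averaging_alpha (illustrative)
tensor = torch.tensor([1.0, 2.0, 3.0])    # local tensor being averaged
update = torch.tensor([0.5, -1.0, 0.0])   # delta produced by all-reduce for this tensor
tensor.add_(update, alpha=alpha)          # in-place: tensor += alpha * update
print(tensor)                             # tensor([1.5000, 1.0000, 3.0000])
```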