|
@@ -317,11 +317,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
if weight is None:
|
|
|
weight = float(self.mode != AveragingMode.AUX)
|
|
|
assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
|
|
|
-
|
|
|
+ logger.debug(f"mer 0")
|
|
|
future = MPFuture()
|
|
|
+ logger.debug(f"mer 1")
|
|
|
gather_binary = self.serializer.dumps(
|
|
|
gather
|
|
|
) # serialize here to avoid loading modules in the averager process
|
|
|
+ logger.debug(f"mer 2")
|
|
|
self._outer_pipe.send(
|
|
|
(
|
|
|
"_step",
|
|
@@ -335,26 +337,28 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
),
|
|
|
)
|
|
|
)
|
|
|
+ logger.debug(f"mer 5")
|
|
|
return future.result() if wait else future
|
|
|
|
|
|
async def _step(
|
|
|
self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
|
|
|
):
|
|
|
+ logger.debug(f"be 0")
|
|
|
start_time = get_dht_time()
|
|
|
|
|
|
try:
|
|
|
while not future.done():
|
|
|
try:
|
|
|
- logger.warning(f"be 1")
|
|
|
+ logger.debug(f"be 1")
|
|
|
self._pending_group_assembled.clear()
|
|
|
- logger.warning(f"be 2")
|
|
|
+ logger.debug(f"be 2")
|
|
|
|
|
|
data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
|
|
|
- logger.warning(f"be 1")
|
|
|
+ logger.debug(f"be 3")
|
|
|
group_info = await self._matchmaking.look_for_group(
|
|
|
timeout=timeout, data_for_gather=data_for_gather
|
|
|
)
|
|
|
- logger.warning(f"be 3")
|
|
|
+ logger.debug(f"be 4")
|
|
|
if group_info is None:
|
|
|
raise AllreduceException("Averaging step failed: could not find a group.")
|
|
|
|