Michael Diskin 4 жил өмнө
parent
commit
1c00a9e5c7

+ 9 - 5
hivemind/averaging/averager.py

@@ -317,11 +317,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
-
+        logger.debug(f"mer 0")
         future = MPFuture()
+        logger.debug(f"mer 1")
         gather_binary = self.serializer.dumps(
             gather
         )  # serialize here to avoid loading modules in the averager process
+        logger.debug(f"mer 2")
         self._outer_pipe.send(
             (
                 "_step",
@@ -335,26 +337,28 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 ),
             )
         )
+        logger.debug(f"mer 5")
         return future.result() if wait else future
 
     async def _step(
         self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
     ):
+        logger.debug(f"be 0")
         start_time = get_dht_time()
 
         try:
             while not future.done():
                 try:
-                    logger.warning(f"be 1")
+                    logger.debug(f"be 1")
                     self._pending_group_assembled.clear()
-                    logger.warning(f"be 2")
+                    logger.debug(f"be 2")
 
                     data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
-                    logger.warning(f"be 1")
+                    logger.debug(f"be 3")
                     group_info = await self._matchmaking.look_for_group(
                         timeout=timeout, data_for_gather=data_for_gather
                     )
-                    logger.warning(f"be 3")
+                    logger.debug(f"be 4")
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")