
Hotfix logging (Task was destroyed but is pending _put_items) (#427)

This PR reduces the frequency at which training prints these warnings to the point where training is bearable again.
However, the warnings still appear in some cases, such as when a training step is cancelled.
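For context, asyncio emits "Task was destroyed but it is pending!" whenever a Task object is garbage-collected before it has finished. A minimal, hivemind-independent reproduction (the coroutine name is chosen only to mirror the helper named in the warning) looks roughly like this:

    import asyncio

    async def _put_items():
        # stand-in for the background producer inside amap_in_executor
        await asyncio.sleep(3600)

    async def main():
        asyncio.ensure_future(_put_items())  # fire-and-forget: never awaited or cancelled

    loop = asyncio.new_event_loop()
    loop.run_until_complete(main())
    loop.close()
    # when the still-pending task is garbage-collected, asyncio logs:
    #   Task was destroyed but it is pending!
    #   task: <Task pending coro=<_put_items() ...>>

Cancelling and then awaiting the task before the loop goes away makes the warning disappear, which is what the asyncio.py change further down does.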


Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic 3 years ago
commit 9a921afee3

+ 2 - 2
hivemind/averaging/allreduce.py

@@ -242,7 +242,7 @@ class AllReduceRunner(ServicerBase):
                     )
             except BaseException as e:
                 if isinstance(e, Exception):
-                    logger.warning(f"Caught {repr(e)} when communicating to {peer_id}")
+                    logger.debug(f"Caught {repr(e)} when communicating to {peer_id}", exc_info=True)
                 self.tensor_part_container.register_failed_reducer(peer_index)
                 raise
 
 
@@ -311,7 +311,7 @@ class AllReduceRunner(ServicerBase):
         except BaseException as e:
             await self._ban_sender(context.remote_id)
             if isinstance(e, Exception):
-                logger.warning(f"Caught {repr(e)} when communicating with {context.remote_id}")
+                logger.debug(f"Caught {repr(e)} when communicating with {context.remote_id}", exc_info=True)
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
             else:
                 raise  # CancelledError, StopIteration and similar
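For reference, the practical effect of the two allreduce.py changes is that a failed peer no longer produces a WARNING with just the exception repr; instead the full traceback is attached to a DEBUG record, so it is only visible when debug logging is enabled. A standalone sketch of that logging pattern (peer_id and the raised error are made up):

    import logging

    logger = logging.getLogger(__name__)

    def communicate(peer_id: str):
        # placeholder for the real streaming RPC; always fails in this sketch
        raise ConnectionError("stream closed by remote peer")

    def run_once(peer_id: str = "QmExamplePeer"):
        try:
            communicate(peer_id)
        except BaseException as e:
            if isinstance(e, Exception):
                # exc_info=True attaches the traceback to the log record,
                # but nothing is printed unless the logger is set to DEBUG
                logger.debug(f"Caught {repr(e)} when communicating to {peer_id}", exc_info=True)
            raise  # CancelledError and similar still propagate without being logged here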

+ 1 - 1
hivemind/averaging/partition.py

@@ -262,7 +262,7 @@ class TensorPartReducer:
                 parts_expected = self.num_parts * self.num_senders
                 parts_received = sum(self.num_parts_received)
                 if parts_expected != parts_received:
-                    logger.info(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")
+                    logger.warning(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")

     def __del__(self):
         self.finalize()
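The partition.py change only promotes the incomplete-reduction message from INFO to WARNING; the reported percentage is simply the ratio of received to expected tensor parts. A quick illustration with made-up numbers:

    num_parts, num_senders = 8, 4        # hypothetical values, not taken from a real run
    num_parts_received = [8, 8, 5, 0]    # one sender failed mid-way, one never sent anything

    parts_expected = num_parts * num_senders   # 32
    parts_received = sum(num_parts_received)   # 21
    if parts_expected != parts_received:
        print(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")
        # -> Reducer: received 65.6% of input tensors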

+ 11 - 10
hivemind/utils/asyncio.py

@@ -106,7 +106,7 @@ async def cancel_and_wait(awaitable: Awaitable) -> bool:
 async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,
-    max_prefetch: Optional[int] = None,
+    max_prefetch: int = 1,
     executor: Optional[ThreadPoolExecutor] = None,
 ) -> AsyncIterator[T]:
     """iterate from an async iterable in a background thread, yield results to async iterable"""
@@ -122,7 +122,6 @@ async def amap_in_executor(
             future = asyncio.Future()
             future.set_exception(e)
             await queue.put(future)
-            raise

     task = asyncio.create_task(_put_items())
     try:
@@ -131,17 +130,19 @@ async def amap_in_executor(
             yield await future
             future = await queue.get()
     finally:
-        task.cancel()
-        try:
-            await task
-        except asyncio.CancelledError:
-            pass
-        except Exception as e:
-            logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
+        awaitables = [task]
         while not queue.empty():
             future = queue.get_nowait()
             if future is not None:
-                future.cancel()
+                awaitables.append(future)
+        for coro in awaitables:
+            coro.cancel()
+            try:
+                await coro
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
+                # note: we do not reraise here because it is already in the finally clause


 async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: Optional[float]) -> AsyncIterator[T]:
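The new finally clause in amap_in_executor replaces "cancel the producer task and cancel any queued futures" with a cancel-then-await pass over all of them, so nothing pending is left around for the garbage collector to complain about. A self-contained sketch of that cleanup pattern (the worker coroutine is a stand-in for the producer task and the queued futures):

    import asyncio
    import logging

    logger = logging.getLogger(__name__)

    async def _worker(i: int):
        # stand-in for the producer task / queued futures inside amap_in_executor
        await asyncio.sleep(3600)

    async def drain(awaitables):
        for coro in awaitables:
            coro.cancel()
            try:
                await coro  # awaiting after cancel() is what prevents the pending-task warning
            except BaseException as e:
                if isinstance(e, Exception):
                    logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
                # CancelledError is swallowed on purpose: this runs during cleanup

    async def main():
        tasks = [asyncio.create_task(_worker(i)) for i in range(3)]
        await asyncio.sleep(0)  # let the workers start
        await drain(tasks)      # cancel and await them all; no warning at shutdown

    asyncio.run(main())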