Browse Source

Hotfix logging (Task was destroyed but is pending _put_items) (#427)

This PR reduces the frequency at which training prints these warnings to an extent that it is now bearable to train.
However, the warnings persist in some cases, such as when a training step is cancelled.


Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic 3 years ago
parent
commit
9a921afee3
3 changed files with 14 additions and 13 deletions
  1. 2 2
      hivemind/averaging/allreduce.py
  2. 1 1
      hivemind/averaging/partition.py
  3. 11 10
      hivemind/utils/asyncio.py

+ 2 - 2
hivemind/averaging/allreduce.py

@@ -242,7 +242,7 @@ class AllReduceRunner(ServicerBase):
                     )
             except BaseException as e:
                 if isinstance(e, Exception):
-                    logger.warning(f"Caught {repr(e)} when communicating to {peer_id}")
+                    logger.debug(f"Caught {repr(e)} when communicating to {peer_id}", exc_info=True)
                 self.tensor_part_container.register_failed_reducer(peer_index)
                 raise
 
@@ -311,7 +311,7 @@ class AllReduceRunner(ServicerBase):
         except BaseException as e:
             await self._ban_sender(context.remote_id)
             if isinstance(e, Exception):
-                logger.warning(f"Caught {repr(e)} when communicating with {context.remote_id}")
+                logger.debug(f"Caught {repr(e)} when communicating with {context.remote_id}", exc_info=True)
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
             else:
                 raise  # CancelledError, StopIteration and similar

+ 1 - 1
hivemind/averaging/partition.py

@@ -262,7 +262,7 @@ class TensorPartReducer:
                 parts_expected = self.num_parts * self.num_senders
                 parts_received = sum(self.num_parts_received)
                 if parts_expected != parts_received:
-                    logger.info(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")
+                    logger.warning(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")
 
     def __del__(self):
         self.finalize()

+ 11 - 10
hivemind/utils/asyncio.py

@@ -106,7 +106,7 @@ async def cancel_and_wait(awaitable: Awaitable) -> bool:
 async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,
-    max_prefetch: Optional[int] = None,
+    max_prefetch: int = 1,
     executor: Optional[ThreadPoolExecutor] = None,
 ) -> AsyncIterator[T]:
     """iterate from an async iterable in a background thread, yield results to async iterable"""
@@ -122,7 +122,6 @@ async def amap_in_executor(
             future = asyncio.Future()
             future.set_exception(e)
             await queue.put(future)
-            raise
 
     task = asyncio.create_task(_put_items())
     try:
@@ -131,17 +130,19 @@ async def amap_in_executor(
             yield await future
             future = await queue.get()
     finally:
-        task.cancel()
-        try:
-            await task
-        except asyncio.CancelledError:
-            pass
-        except Exception as e:
-            logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
+        awaitables = [task]
         while not queue.empty():
             future = queue.get_nowait()
             if future is not None:
-                future.cancel()
+                awaitables.append(future)
+        for coro in awaitables:
+            coro.cancel()
+            try:
+                await coro
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
+                # note: we do not reraise here because it is already in the finally clause
 
 
 async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: Optional[float]) -> AsyncIterator[T]: