Parcourir la source

remove debug messages

justheuristic il y a 4 ans
Parent
commit
83cfb74b1a
3 fichiers modifiés avec 20 ajouts et 29 suppressions
  1. 0 25
      examples/albert/run_trainer.py
  2. 9 4
      hivemind/averaging/averager.py
  3. 11 0
      hivemind/utils/asyncio.py

+ 0 - 25
examples/albert/run_trainer.py

@@ -30,31 +30,6 @@ logger = logging.getLogger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 
 
-def analyze_openfiles_periodically():
-    while True:
-        logger.info(f"Scanning open files for process {psutil.Process().pid}")
-        children = [psutil.Process()] + psutil.Process().children(recursive=True)
-        for child in children:
-            try:
-                num_open_files = len(child.open_files())
-            except:
-                num_open_files = "FAILED"
-            logger.info(f"proc: '{child.name()}' pid={child.pid} parent={child.parent().pid} files: {num_open_files}")
-        for child in children:
-            try:
-                open_files = child.open_files()
-                if len(open_files) > 100:
-                    logger.info(f"proc: {child.name()} has {len(open_files)} open files: {repr(open_files)}")
-            except:
-                pass
-        logger.info("DONE scanning")
-        time.sleep(300)
-
-
-analyzer = threading.Thread(target=analyze_openfiles_periodically)
-analyzer.start()
-
-
 def setup_logging(training_args):
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",

+ 9 - 4
hivemind/averaging/averager.py

@@ -25,7 +25,7 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2, runtime_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
+from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop, aiter_with_timeout
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
@@ -197,6 +197,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def peer_id(self) -> PeerID:
         return self.dht.peer_id
 
+    @property
+    def request_timeout(self):
+        return self._matchmaking.request_timeout
+
     def run(self):
         """
         Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
@@ -245,7 +249,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             while True:
                 try:
-                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._matchmaking.request_timeout)
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self.request_timeout)
                 except asyncio.TimeoutError:
                     pass
                 if not self._inner_pipe.poll():
@@ -254,7 +258,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     method, args, kwargs = self._inner_pipe.recv()
                 except (OSError, ConnectionError, RuntimeError) as e:
                     logger.exception(e)
-                    await asyncio.sleep(self._matchmaking.request_timeout)
+                    await asyncio.sleep(self.request_timeout)
                     continue
                 task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
                 if method == "_shutdown":
@@ -588,7 +592,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
-                        async for message in stream:
+
+                        async for message in aiter_with_timeout(stream, timeout=self._matchmaking.request_timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:

+ 11 - 0
hivemind/utils/asyncio.py

@@ -127,3 +127,14 @@ async def amap_in_executor(
     finally:
         if not task.done():
             task.cancel()
+
+
+async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
+    """ Iterate over an async iterable, raise TimeoutError if another portion of data does not arrive within timeout """
+    # based on https://stackoverflow.com/a/50245879
+    iterator = iterable.__aiter__()
+    while True:
+        try:
+            yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
+        except StopAsyncIteration:
+            break