瀏覽代碼

Improve Runtime exception handling (#207)

* Stop dht_handler_thread only if serving experts

* Switch to uvloop in ConnectionHandler

* Properly handle exceptions if setting a result for finished futures in TaskPool
Max Ryabinin 4 年之前
父節點
當前提交
6f8f192150

+ 1 - 1
hivemind/dht/node.py

@@ -121,7 +121,7 @@ class DHTNode:
         :param kwargs: extra parameters used in grpc.aio.server
         """
         self = cls(_initialized_with_create=True)
-        self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
+        self.node_id = node_id if node_id is not None else DHTID.generate()
         self.num_replicas, self.num_workers, self.chunk_size = num_replicas, num_workers, chunk_size
         self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
 

+ 1 - 1
hivemind/server/__init__.py

@@ -205,7 +205,7 @@ class Server(threading.Thread):
 
         for process in self.conn_handlers:
             process.join()
-        if self.dht:
+        if self.dht and self.experts:
             dht_handler_thread.stop.set()
             dht_handler_thread.join()
         if self.checkpoint_saver is not None:

+ 2 - 2
hivemind/server/connection_handler.py

@@ -12,6 +12,7 @@ from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.server.expert_backend import ExpertBackend
 from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, nested_flatten
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
+from hivemind.utils.asyncio import switch_to_uvloop
 
 logger = get_logger(__name__)
 
@@ -32,8 +33,7 @@ class ConnectionHandler(mp.context.ForkProcess):
 
     def run(self):
         torch.set_num_threads(1)
-        uvloop.install()
-        loop = asyncio.new_event_loop()
+        loop = switch_to_uvloop()
 
         async def _run():
             grpc.aio.init_grpc_aio()

+ 11 - 5
hivemind/server/task_pool.py

@@ -15,7 +15,7 @@ from typing import List, Tuple, Dict, Any, Generator
 
 import torch
 
-from hivemind.utils import MPFuture, get_logger
+from hivemind.utils import MPFuture, get_logger, FutureStateError
 
 logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
@@ -125,9 +125,12 @@ class TaskPool(TaskPoolBase):
                 batch = []
                 total_size = 0
 
-            if task.future.set_running_or_notify_cancel():
-                batch.append(task)
-                total_size += task_size
+            try:
+                if task.future.set_running_or_notify_cancel():
+                    batch.append(task)
+                    total_size += task_size
+            except FutureStateError as e:
+                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
 
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
@@ -199,7 +202,10 @@ class TaskPool(TaskPoolBase):
 
                 # dispatch results to futures
                 for task, task_outputs in zip(batch_tasks, outputs_per_task):
-                    task.future.set_result(tuple(task_outputs))
+                    try:
+                        task.future.set_result(tuple(task_outputs))
+                    except FutureStateError as e:
+                        logger.debug(f"Failed to send task result due to an exception: {e}")
         except KeyboardInterrupt:
             logger.debug(f"Caught KeyboardInterrupt, shutting down")
 

+ 12 - 7
hivemind/utils/mpfuture.py

@@ -1,10 +1,10 @@
 from __future__ import annotations
-import time
-import multiprocessing as mp
-import multiprocessing.connection
-import concurrent.futures._base as base
 
 import asyncio
+import concurrent.futures._base as base
+import multiprocessing as mp
+import multiprocessing.connection
+import time
 from functools import lru_cache
 from typing import Optional, Tuple, Generic, TypeVar
 
@@ -13,6 +13,11 @@ from hivemind.utils.threading import run_in_background
 ResultType = TypeVar('ResultType')
 
 
+class FutureStateError(RuntimeError):
+    """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
+    pass
+
+
 class MPFuture(base.Future, Generic[ResultType]):
     """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
 
@@ -79,14 +84,14 @@ class MPFuture(base.Future, Generic[ResultType]):
     def set_result(self, result: ResultType):
         self._sync_updates()
         if self._state in self.TERMINAL_STATES:
-            raise RuntimeError(f"Can't set_result to a future that is in {self._state}")
+            raise FutureStateError(f"Can't set_result to a future that is {self._state} ({self})")
         self._state, self._result = base.FINISHED, result
         return self._send_updates()
 
     def set_exception(self, exception: BaseException):
         self._sync_updates()
         if self._state in self.TERMINAL_STATES:
-            raise RuntimeError(f"Can't set_exception to a future that is in {self._state}")
+            raise FutureStateError(f"Can't set_exception to a future that is {self._state} ({self})")
         self._state, self._exception = base.FINISHED, exception
         self._send_updates()
 
@@ -98,7 +103,7 @@ class MPFuture(base.Future, Generic[ResultType]):
         elif self._state == base.CANCELLED:
             return False
         else:
-            raise RuntimeError(f"Can't set_running_or_notify_cancel to a future that is in {self._state}")
+            raise FutureStateError(f"Can't set_running_or_notify_cancel to a future that is in {self._state} ({self})")
 
     def cancel(self):
         self._sync_updates()

+ 14 - 14
tests/test_util_modules.py

@@ -1,15 +1,16 @@
 import asyncio
-import torch
-import numpy as np
+from concurrent.futures import CancelledError
 
+import numpy as np
 import pytest
-import hivemind
+import torch
+
 from hivemind.proto.dht_pb2_grpc import DHTStub
-from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import MSGPackSerializer
-from concurrent.futures import CancelledError
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
+import hivemind
+from hivemind.utils import MSGPackSerializer, serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.mpfuture import FutureStateError
 
 
 def test_mpfuture_result():
@@ -19,9 +20,9 @@ def test_mpfuture_result():
     assert f1.result() == 321
 
     for future in [f1, f2]:
-        with pytest.raises(RuntimeError):
+        with pytest.raises(FutureStateError):
             future.set_result(123)
-        with pytest.raises(RuntimeError):
+        with pytest.raises(FutureStateError):
             future.set_exception(ValueError())
         assert future.cancel() is False
         assert future.done() and not future.running() and not future.cancelled()
@@ -58,9 +59,9 @@ def test_mpfuture_cancel():
             future.result()
         with pytest.raises(CancelledError):
             future.exception()
-        with pytest.raises(RuntimeError):
+        with pytest.raises(FutureStateError):
             future.set_result(123)
-        with pytest.raises(RuntimeError):
+        with pytest.raises(FutureStateError):
             future.set_exception(NotImplementedError())
         assert future.cancelled() and future.done() and not future.running()
 
@@ -204,7 +205,6 @@ def test_serialize_tensor():
     assert torch.allclose(hivemind.deserialize_torch_tensor(serialized_scalar), scalar)
 
 
-
 def test_serialize_tuple():
     test_pairs = (
         ((1, 2, 3), [1, 2, 3]),
@@ -264,5 +264,5 @@ def test_generic_data_classes():
     assert heap_entry.key == "string_value" and heap_entry.expiration_time == DHTExpiration(10)
 
     sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
-    sorted_heap_entry = sorted([HeapEntry(expiration_time=DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
-    assert all([heap_entry.expiration_time == value for heap_entry, value in zip(sorted_heap_entry, sorted_expirations)])
+    sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
+    assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])