
Handle errors in Runtime (#489)

- fix an edge case where expert requests with a 3.99-4 MB payload would fail: serialization overhead pushed the serialized message over the max message size (a size illustration follows the expert.py diff below)
- recover from errors in the Runtime, propagate them to the corresponding tasks
   - previously, a failing function would terminate the entire server, which was a major pain for me personally :)
   - failure to process a request now raises P2PHandlerError instead of P2PDaemonError, because the error no longer kills the daemon (see the client-side sketch right after this list)
- allow optional metadata in ExpertRequest / ExpertResponse for extensibility [todo: validate it vs. @mryab ] (a field-usage sketch follows the runtime.proto diff below)
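
A minimal client-side sketch of the new failure mode; the server setup and the `real_expert` handle mirror tests/test_moe.py and are assumptions, not part of this commit:

    # Sketch only: assumes a running server that exposes an expert with hidden_dim=16
    # behind the hypothetical handle `real_expert` (e.g. created via background_server)
    import torch
    from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError

    try:
        real_expert(torch.randn(3, 11))  # wrong input dimension: the remote forward fails
    except P2PHandlerError as e:
        # the handler error relays the remote error message; the server itself keeps running
        print(f"request failed remotely: {e}")

    output = real_expert(torch.randn(3, 16))  # subsequent well-formed requests still succeed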

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Pavel Samygin <samygin@phystech.edu>
justheuristic 3 years ago
parent
commit
ef0b842baf

+ 3 - 3
hivemind/moe/client/expert.py

@@ -13,7 +13,7 @@ from hivemind.dht import DHT
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, PeerID, StubBase
-from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
 from hivemind.utils.mpfuture import MPFuture
@@ -152,7 +152,7 @@ async def expert_backward(
     size = 0
     for t in inputs_and_grads:
         size += t.element_size() * t.nelement()
-        if size > DEFAULT_MAX_MSG_SIZE:
+        if size > MAX_UNARY_PAYLOAD_SIZE:
             return await _backward_stream(uid, serialized_tensors, stub)
     else:
         return await _backward_unary(uid, serialized_tensors, stub)
@@ -185,7 +185,7 @@ async def expert_forward(
     size = 0
     for t in inputs:
         size += t.element_size() * t.nelement()
-        if size > DEFAULT_MAX_MSG_SIZE:
+        if size > MAX_UNARY_PAYLOAD_SIZE:
             return await _forward_stream(uid, serialized_tensors, stub)
     else:
         return await _forward_unary(uid, serialized_tensors, stub)
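
For intuition on why the threshold moved from DEFAULT_MAX_MSG_SIZE to MAX_UNARY_PAYLOAD_SIZE, here is a rough, self-contained illustration; the overhead figure is an assumption for illustration only:

    # why a 3.99-4 MB unary payload could fail before this change (illustrative numbers)
    DEFAULT_MAX_MSG_SIZE = 4 * 1024**2                  # daemon-side limit on a single message
    MAX_UNARY_PAYLOAD_SIZE = DEFAULT_MAX_MSG_SIZE // 2

    payload = int(3.99 * 1024**2)                       # raw tensor bytes, just under the old threshold
    overhead = 64 * 1024                                # hypothetical serialization/envelope overhead

    assert payload <= DEFAULT_MAX_MSG_SIZE              # old check: the unary path would be chosen...
    assert payload + overhead > DEFAULT_MAX_MSG_SIZE    # ...yet the serialized message exceeds the limit
    assert payload > MAX_UNARY_PAYLOAD_SIZE             # new check: such payloads take the streaming path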

+ 14 - 7
hivemind/moe/server/runtime.py

@@ -85,16 +85,23 @@ class Runtime(threading.Thread):
                     logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
 
                     start = time()
-                    outputs = pool.process_func(*batch)
-                    batch_processing_time = time() - start
+                    try:
+                        outputs = pool.process_func(*batch)
+                        batch_processing_time = time() - start
 
-                    batch_size = outputs[0].size(0)
-                    logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
+                        batch_size = outputs[0].size(0)
+                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
 
-                    if self.stats_report_interval is not None:
-                        self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
+                        if self.stats_report_interval is not None:
+                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
+
+                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
+                    except KeyboardInterrupt:
+                        raise
+                    except BaseException as exception:
+                        logger.exception(f"Caught {exception}, attempting to recover")
+                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])
 
-                    output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
             finally:
                 if not self.shutdown_trigger.is_set():
                     self.shutdown()

+ 30 - 13
hivemind/moe/server/task_pool.py

@@ -195,21 +195,35 @@ class TaskPool(TaskPoolBase):
 
         while True:
             logger.debug(f"{self.name} waiting for results from runtime")
-            batch_index, batch_outputs = self.outputs_receiver.recv()
-            logger.debug(f"{self.name}, batch {batch_index}: got results")
-
-            # split batch into partitions for individual tasks
+            batch_index, batch_outputs_or_exception = self.outputs_receiver.recv()
             batch_tasks = pending_batches.pop(batch_index)
-            task_sizes = [self.get_task_size(task) for task in batch_tasks]
-            outputs_per_task = zip(*(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs))
-            logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")
 
-            # dispatch results to futures
-            for task, task_outputs in zip(batch_tasks, outputs_per_task):
-                try:
-                    task.future.set_result(tuple(task_outputs))
-                except InvalidStateError as e:
-                    logger.debug(f"Failed to send task result due to an exception: {e}")
+            if isinstance(batch_outputs_or_exception, BaseException):
+                logger.debug(f"{self.name}, batch {batch_index}: got exception, propagating to handlers")
+                exception = batch_outputs_or_exception
+                for task in batch_tasks:
+                    try:
+                        task.future.set_exception(exception)
+                    except InvalidStateError as e:
+                        logger.debug(f"Failed to send runtime error to a task: {e}")
+
+            else:
+                logger.debug(f"{self.name}, batch {batch_index}: got results")
+                batch_outputs = batch_outputs_or_exception
+
+                # split batch into partitions for individual tasks
+                task_sizes = [self.get_task_size(task) for task in batch_tasks]
+                outputs_per_task = zip(
+                    *(torch.split_with_sizes(tensor, task_sizes, dim=0) for tensor in batch_outputs)
+                )
+                logger.debug(f"{self.name}, batch {batch_index}: sending outputs to handlers")
+
+                # dispatch results to futures
+                for task, task_outputs in zip(batch_tasks, outputs_per_task):
+                    try:
+                        task.future.set_result(tuple(task_outputs))
+                    except InvalidStateError as e:
+                        logger.debug(f"Failed to send task result due to an exception: {e}")
 
     @property
     def empty(self):
@@ -232,6 +246,9 @@ class TaskPool(TaskPoolBase):
         ]
         self.outputs_sender.send((batch_index, batch_outputs))
 
+    def send_exception_from_runtime(self, batch_index: int, exception: BaseException):
+        self.outputs_sender.send((batch_index, exception))
+
     def get_task_size(self, task: Task) -> int:
         """compute task processing complexity (used for batching); defaults to batch size"""
         return len(task.args[0]) if task.args else 1
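
The dispatch above relies on standard future semantics: once set_exception is called, whoever waits on the task future has the error re-raised. A standalone illustration using concurrent.futures.Future, which hivemind's MPFuture mirrors in this respect:

    # minimal illustration of why set_exception is enough to notify the waiting task
    from concurrent.futures import Future

    future = Future()
    future.set_exception(ValueError("runtime failed to process this batch"))
    try:
        future.result()
    except ValueError as e:
        print(f"caller observes the original error: {e}")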

+ 3 - 0
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -27,6 +27,9 @@ SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CON
 logger = get_logger(__name__)
 
 DEFAULT_MAX_MSG_SIZE = 4 * 1024**2
+MAX_UNARY_PAYLOAD_SIZE = DEFAULT_MAX_MSG_SIZE // 2
+# note: unary payloads are limited to half of the max message size to leave room for serialization overhead.
+# The actual overhead is typically smaller; we err on the side of streaming, because even 2 MB messages can be streamed efficiently.
 
 
 def parse_conn_protocol(maddr: Multiaddr) -> int:

+ 2 - 0
hivemind/proto/runtime.proto

@@ -12,10 +12,12 @@ message ExpertInfo {
 message ExpertRequest {
   string uid = 1;
   repeated Tensor tensors = 2;
+  string metadata = 3;
 }
 
 message ExpertResponse {
   repeated Tensor tensors = 2;
+  string metadata = 3;
 }
 
 enum CompressionType{
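
A minimal sketch of populating the new optional field through the generated bindings; the uid and metadata payload below are arbitrary examples, since the commit intentionally leaves the metadata format unspecified (see the TODO in the commit message):

    # sketch only: the metadata content/format is not defined by this commit
    from hivemind.proto import runtime_pb2

    request = runtime_pb2.ExpertRequest(uid="expert.0", metadata='{"schema_version": 1}')
    print(request.metadata)   # {"schema_version": 1}

    response = runtime_pb2.ExpertResponse()
    print(response.metadata)  # empty string: proto3 string fields default to "", so the field stays optional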

+ 9 - 3
tests/test_moe.py

@@ -9,7 +9,7 @@ from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
 from hivemind.moe.server.layers import name_to_block
-from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
+from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
 from hivemind.utils import BatchTensorDescriptor, get_dht_time
 
 
@@ -153,11 +153,17 @@ def test_remote_module_call(hidden_dim=16):
         out3_again.norm().backward()
         assert dummy_x.grad is not None and dummy_x.grad.norm() > 0
 
-        with pytest.raises(P2PDaemonError):
+        try:
             real_expert(torch.randn(3, 11))
-        with pytest.raises(P2PDaemonError):
+        except P2PHandlerError as e:
+            assert str(11) in repr(e), "Exception must relay the remote server error (i.e. incorrect dimensions)"
+        with pytest.raises(P2PHandlerError):
             fake_expert(dummy_x)
 
+        # check that the server is still alive after processing a malformed request
+        out3_yet_again = real_expert(dummy_x[1:])
+        assert torch.allclose(out3_yet_again, out3[1:])
+
 
 @pytest.mark.forked
 def test_beam_search_correctness():