Ver código fonte

Don't ban servers in case of client-caused handler errors (#134)

Alexander Borzunov 3 anos atrás
pai
commit
1fe3716589

+ 2 - 2
src/petals/client/inference_session.py

@@ -17,7 +17,7 @@ from hivemind import (
     serialize_torch_tensor,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import StubBase
+from hivemind.p2p import P2PHandlerError, StubBase
 from hivemind.proto import runtime_pb2
 
 from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
@@ -300,7 +300,7 @@ class InferenceSession:
                     self._sequence_manager.on_request_success(span.peer_id)
                     break
                 except Exception as e:
-                    if span is not None:
+                    if span is not None and not isinstance(e, P2PHandlerError):
                         self._sequence_manager.on_request_failure(span.peer_id)
                     delay = self._sequence_manager.get_retry_delay(attempt_no)
                     logger.warning(

+ 2 - 1
src/petals/client/routing/sequence_manager.py

@@ -12,6 +12,7 @@ from weakref import WeakMethod
 from hivemind import DHT, P2P, MSGPackSerializer, PeerID
 from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import P2PHandlerError
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
@@ -233,7 +234,7 @@ class RemoteSequenceManager:
                     self.on_request_success(peer_id)
                     break
                 except Exception as e:
-                    if peer_id is not None:
+                    if peer_id is not None and not isinstance(e, P2PHandlerError):
                         self.on_request_failure(peer_id)
                     delay = self.get_retry_delay(attempt_no)
                     logger.warning(

+ 3 - 2
src/petals/client/sequential_autograd.py

@@ -10,6 +10,7 @@ from typing import List, Optional, Sequence, Tuple
 import torch
 from hivemind import MSGPackSerializer
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import P2PHandlerError
 from hivemind.utils.logging import get_logger
 
 from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
@@ -93,7 +94,7 @@ async def sequential_forward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None:
+                if span is not None and not isinstance(e, P2PHandlerError):
                     sequence_manager.on_request_failure(span.peer_id)
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
@@ -170,7 +171,7 @@ async def sequential_backward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None:
+                if span is not None and not isinstance(e, P2PHandlerError):
                     sequence_manager.on_request_failure(span.peer_id)
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(