
Make client ignore blacklist if all servers holding a block are blacklisted (#197)

If all servers holding a certain block are blacklisted, we should display errors from them instead of raising `No peers holding blocks`.

If the error is client-caused, the client should learn its reason from the latest error messages. Conversely, if the error is server- or network-caused and only a few servers are available, it is better to see the actual error than to ban all of them and leave the user thinking that no servers are available.
Alexander Borzunov 2 years ago
parent commit b4f3224cda

src/petals/client/inference_session.py (+2 -2)

@@ -17,7 +17,7 @@ from hivemind import (
     serialize_torch_tensor,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2PHandlerError, StubBase
+from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
 
 from petals.client.routing.sequence_manager import RemoteSequenceManager, maybe_log_traceback
@@ -305,7 +305,7 @@ class InferenceSession:
                     self._sequence_manager.on_request_success(span.peer_id)
                     break
                 except Exception as e:
-                    if span is not None and not isinstance(e, P2PHandlerError):
+                    if span is not None:
                         self._sequence_manager.on_request_failure(span.peer_id)
                     delay = self._sequence_manager.get_retry_delay(attempt_no)
                     logger.warning(
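
The same exemption for P2PHandlerError is removed in sequential_autograd.py below, so every failed attempt now calls on_request_failure, including errors raised by the remote handler. A minimal sketch of the resulting retry pattern, with hypothetical names (run_with_retries, call_peer, max_attempts); only on_request_success, on_request_failure, and get_retry_delay come from the diff:

    import time

    def run_with_retries(sequence_manager, peer_id, call_peer, max_attempts=3):
        # Toy retry loop mirroring the pattern above; not the petals implementation.
        last_exc = None
        for attempt_no in range(max_attempts):
            try:
                result = call_peer(peer_id)
                sequence_manager.on_request_success(peer_id)
                return result
            except Exception as e:
                # Before this commit, P2PHandlerError was exempt from failure accounting;
                # now every exception counts toward the peer's blacklist record.
                sequence_manager.on_request_failure(peer_id)
                last_exc = e
                time.sleep(sequence_manager.get_retry_delay(attempt_no))
        raise last_exc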

src/petals/client/routing/sequence_manager.py (+14 -4)

@@ -156,10 +156,20 @@ class RemoteSequenceManager:
                 for block_info in new_block_infos:
                     if not block_info:
                         continue
-                    for peer_id in tuple(block_info.servers.keys()):
-                        if peer_id in self.banned_peers:
-                            logger.debug(f"Ignoring banned {peer_id} for block {block_info.uid}")
-                            block_info.servers.pop(peer_id)
+                    valid_servers = {
+                        peer_id: server_info
+                        for peer_id, server_info in block_info.servers.items()
+                        if peer_id not in self.banned_peers
+                    }
+                    if len(valid_servers) < len(block_info.servers):
+                        if valid_servers:
+                            logger.debug(
+                                f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
+                            )
+                            block_info.servers = valid_servers
+                        else:
+                            # If we blacklisted all servers, the error may actually be client-caused
+                            logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
 
                 with self.lock_changes:
                     self.sequence_info.update_(new_block_infos)
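
The rule introduced in this hunk can be summarized as a standalone helper (plain Python with hypothetical names, not the petals API): drop blacklisted servers for a block, but fall back to the full server set whenever the blacklist would leave nothing. Unlike the old behavior of unconditionally popping banned peers, this can never empty block_info.servers and trigger `No peers holding blocks`.

    def filter_servers(servers: dict, banned_peers: set) -> dict:
        # Keep only servers whose peer_id is not blacklisted.
        valid_servers = {peer_id: info for peer_id, info in servers.items() if peer_id not in banned_peers}
        # If every holder of this block is banned, the error may actually be client-caused,
        # so keep all servers and let their error messages reach the user.
        return valid_servers if valid_servers else servers

    assert filter_servers({"peer-A": "info", "peer-B": "info"}, banned_peers={"peer-A"}) == {"peer-B": "info"}
    assert filter_servers({"peer-A": "info"}, banned_peers={"peer-A"}) == {"peer-A": "info"}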

src/petals/client/sequential_autograd.py (+2 -3)

@@ -10,7 +10,6 @@ from typing import List, Optional, Sequence, Tuple
 import torch
 from hivemind import MSGPackSerializer
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2PHandlerError
 from hivemind.utils.logging import get_logger
 
 from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
@@ -94,7 +93,7 @@ async def sequential_forward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None and not isinstance(e, P2PHandlerError):
+                if span is not None:
                     sequence_manager.on_request_failure(span.peer_id)
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
@@ -171,7 +170,7 @@ async def sequential_backward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None and not isinstance(e, P2PHandlerError):
+                if span is not None:
                     sequence_manager.on_request_failure(span.peer_id)
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
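
Taken together, a toy end-to-end illustration (ToyBlacklist and its ban threshold are hypothetical; only the fallback rule comes from this commit): repeated failures still get a peer banned, but a block whose only holders are banned remains reachable, so the user sees the underlying error instead of `No peers holding blocks`.

    class ToyBlacklist:
        # Hypothetical stand-in for the client's ban bookkeeping.
        def __init__(self, ban_after: int = 2):
            self.ban_after = ban_after
            self.failures = {}
            self.banned_peers = set()

        def on_request_failure(self, peer_id):
            self.failures[peer_id] = self.failures.get(peer_id, 0) + 1
            if self.failures[peer_id] >= self.ban_after:
                self.banned_peers.add(peer_id)

        def servers_for_block(self, servers: dict) -> dict:
            valid = {p: s for p, s in servers.items() if p not in self.banned_peers}
            return valid if valid else servers  # ignore the blacklist once it would empty the block

    blacklist = ToyBlacklist()
    holders = {"peer-1": "the only server holding this block"}
    for _ in range(2):
        blacklist.on_request_failure("peer-1")  # e.g. repeated remote handler errors
    assert "peer-1" in blacklist.banned_peers
    assert blacklist.servers_for_block(holders) == holders  # still queried, so its error surfaces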