
Don't ban experts for timeout

Max Ryabinin, 3 years ago
commit b9ccbe7b48
2 changed files with 12 additions and 2 deletions
  1. hivemind/moe/client/balanced_expert.py (+2 -2)
  2. hivemind/moe/client/balancer.py (+10 -0)

+ 2 - 2
hivemind/moe/client/balanced_expert.py

@@ -126,7 +126,7 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             except KeyboardInterrupt:
                 raise
             except BaseException as e:
-                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
+                logger.exception(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
         return tuple(deserialized_outputs)
@@ -150,6 +150,6 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             except KeyboardInterrupt:
                 raise
             except BaseException as e:
-                logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
+                logger.exception(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, None, None, None, *deserialized_grad_inputs)
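The only change in this file swaps logger.error for logger.exception, which logs the same message at ERROR level but also attaches the traceback of the exception currently being handled. A minimal standalone sketch of the difference (the call_expert helper and the logger name are illustrative, not part of hivemind):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("demo")

    def call_expert():
        # stand-in for the real gRPC forward call; always fails in this demo
        raise TimeoutError("expert did not respond in time")

    try:
        call_expert()
    except KeyboardInterrupt:
        raise
    except BaseException as e:
        # logger.error would record only the message below; logger.exception
        # also appends the full traceback of the active exception, which is
        # what this commit switches to for the forward/backward error paths.
        logger.exception(f"Tried to call forward for expert demo_expert but caught {repr(e)}")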

+ 10 - 0
hivemind/moe/client/balancer.py

@@ -5,6 +5,9 @@ from contextlib import contextmanager
 from typing import Dict, List, Tuple
 import time
 
+import grpc
+from grpc._channel import _InactiveRpcError
+
 from hivemind.dht import DHT
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import ExpertPrefix, ExpertUID
@@ -127,6 +130,13 @@ class ExpertBalancer:
                 logger.debug(f"Using expert {uid}, throughput = {self.throughputs[uid].samples_per_second}.")
                 yield RemoteExpert(uid, maybe_endpoint.value)
                 logger.debug(f"Finished using expert {uid}.")
+        except _InactiveRpcError as error:
+            if error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
+                # response was too slow, choose the next expert
+                pass
+            else:
+                self._ban_expert(uid)
+                raise
         except BaseException:
             self._ban_expert(uid)
             raise
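The net effect of the hunk above: if the expert's RPC fails with DEADLINE_EXCEEDED, the expert was merely too slow for this call, so the balancer does not ban it and control falls through so another expert can be tried; any other gRPC status, or any non-gRPC failure, still bans the expert and re-raises. A hedged sketch of that policy as a standalone predicate (should_ban is illustrative, not a function in hivemind):

    import grpc

    def should_ban(error: BaseException) -> bool:
        # Mirrors the policy added in this commit: a deadline-exceeded RPC is
        # treated as a transient timeout (no ban, try the next expert), while
        # every other error bans the expert before being re-raised.
        if isinstance(error, grpc.RpcError) and error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            return False
        return True

In the diff itself this predicate is inlined: the DEADLINE_EXCEEDED branch simply passes, and the remaining branches call self._ban_expert(uid) and re-raise.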