Your Name committed 1 year ago
parent commit cc4fe17a99
3 files changed, 7 insertions(+) and 3 deletions(-):
  1. src/petals/__init__.py (+3 −0)
  2. src/petals/server/backend.py (+1 −1)
  3. src/petals/server/server.py (+3 −2)

src/petals/__init__.py (+3 −0)

@@ -24,6 +24,9 @@ if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
         version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
     ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
+    assert version.parse("1.1.10") <= version.parse(
+        hivemind.__version__
+    ), "Please install a proper hivemind version: pip install hivemind>=1.1.10"
 
 
 def _override_bfloat16_mode_default():

src/petals/server/backend.py (+1 −1)

@@ -221,7 +221,7 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend])
     first_pool = next(iter(backends.values())).inference_pool
     merged_inference_func = _MergedInferenceStep(backends)
     merged_pool = PrioritizedTaskPool(
-        lambda args, kwargs: merged_inference_func(*args, **kwargs),
+        merged_inference_func,
         max_batch_size=first_pool.max_batch_size,
         device=first_pool.device,
         name=f"merged_inference",

src/petals/server/server.py (+3 −2)

@@ -8,7 +8,7 @@ import random
 import sys
 import threading
 import time
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import hivemind
 import psutil
@@ -17,6 +17,7 @@ import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
+from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
@@ -778,7 +779,7 @@ class RuntimeWithDeduplicatedPools(Runtime):
         outputs = pool.process_func(*args, **kwargs)
         batch_size = 1
         for arg in args:
-            if isintance(arg, torch.Tensor) and arg.ndim > 2:
+            if isinstance(arg, torch.Tensor) and arg.ndim > 2:
                 batch_size = arg.shape[0] * arg.shape[1]
                 break
         return outputs, batch_size