
Fix bugs/circular imports, raise timeouts

Max Ryabinin, 4 years ago
parent
commit ede2be17ca

+ 1 - 1
hivemind/hivemind_cli/run_server.py

@@ -79,7 +79,7 @@ def main():
 
     parser.add_argument('--averaging_min_refresh_period',type=float,default=1)
     parser.add_argument('--averaging_max_refresh_period',type=float,default=60)
-    parser.add_argument('--averaging_default_refresh_period',type=float,default=5)
+    parser.add_argument('--averaging_default_refresh_period',type=float,default=10)
     parser.add_argument('--averaging_expiration',type=float,default=30)
     parser.add_argument('--metadata_expiration',type=float,default=120)
     parser.add_argument('--averaging_timeout',type=float,default=30)
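The only change in this file is a bumped default: the averaging refresh period now defaults to 10 seconds instead of 5, matching the new Server default later in this commit. As a minimal sketch (a standalone argparse mirror of the flag above, not the real run_server.py), the raised default only applies when the flag is omitted on the command line:

import argparse

# Standalone re-declaration of the flag, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument('--averaging_default_refresh_period', type=float, default=10)

print(parser.parse_args([]).averaging_default_refresh_period)  # 10.0 (new default)
print(parser.parse_args(['--averaging_default_refresh_period', '5']).averaging_default_refresh_period)  # 5.0 (explicit override)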

+ 12 - 10
hivemind/moe/client/balanced_expert.py

@@ -4,8 +4,8 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.moe.client.balancer import ExpertBalancer
 from hivemind.moe.client.expert import DUMMY
 from hivemind.proto import runtime_pb2
@@ -23,7 +23,7 @@ class BalancedRemoteExpert(nn.Module):
     def __init__(
         self,
         *,
-        dht: hivemind.DHT,
+        dht: DHT,
         uid_prefix: str,
         grid_size: Tuple[int, ...],
         forward_timeout: Optional[float] = None,
@@ -118,13 +118,15 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
         while True:
-            try:
-                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
-                    forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
-                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
-                break
-            except BaseException as e:
-                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
+            # try:
+            with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
+                outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
+            break
+            # except KeyboardInterrupt:
+            #     break
+            # except BaseException as e:
+            #     logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
         return tuple(deserialized_outputs)
@@ -143,7 +145,7 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             try:
                 with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
                     backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
-                    grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
+                    grad_inputs = chosen_expert.stub.backward(backward_request, timeout=ctx.backward_timeout)
                 break
             except BaseException as e:
                 logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")

+ 3 - 3
hivemind/moe/client/balancer.py

@@ -4,11 +4,11 @@ import threading
 from contextlib import contextmanager
 from typing import Dict, List, Tuple
 
-from hivemind import Endpoint, RemoteExpert, TimedStorage
 from hivemind.dht import DHT
+from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import ExpertPrefix, ExpertUID
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
+from hivemind.utils import Endpoint, TimedStorage, DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -71,7 +71,7 @@ class ExpertBalancer:
             self.experts.store(uid, endpoint, expiration_time)
             if uid not in self.uid_to_queue:
                 logger.debug(f"Adding new expert: {uid}, expiration time = {expiration_time:.3f}.")
-                self.throughputs[uid] = PerformanceEMA(*self.ema_kwargs, paused=True)
+                self.throughputs[uid] = PerformanceEMA(**self.ema_kwargs, paused=True)
                 base_load = self.queue[0][0] if len(self.queue) > 0 else 0.0
                 heap_entry = (base_load, random.random(), uid)
                 heapq.heappush(self.queue, heap_entry)
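Besides re-routing the Endpoint/RemoteExpert/TimedStorage imports to break the cycle with the top-level package, this file fixes an argument-unpacking bug: expanding the EMA keyword arguments with a single star passes only the dict's keys as positional arguments, while the double star passes them as keyword arguments. A minimal sketch with a toy function (not PerformanceEMA's actual signature) shows the difference:

# Toy stand-in for an EMA constructor; the parameter names are illustrative only.
def make_ema(alpha=0.1, eps=1e-6, paused=False):
    return alpha, eps, paused

ema_kwargs = {"alpha": 0.05, "eps": 1e-8}

print(make_ema(**ema_kwargs, paused=True))  # (0.05, 1e-08, True) -- intended behaviour
print(make_ema(*ema_kwargs, paused=True))   # ('alpha', 'eps', True) -- keys passed positionally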

+ 4 - 2
hivemind/moe/server/__init__.py

@@ -12,6 +12,7 @@ import torch
 from multiaddr import Multiaddr
 
 import hivemind
+from hivemind.compression import BASE_COMPRESSION_TYPES
 from hivemind.dht import DHT
 from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
@@ -113,7 +114,7 @@ class Server(threading.Thread):
         averaging_target_group_size: Optional[int] = None,
         averaging_min_refresh_period=1,
         averaging_max_refresh_period=60,
-        averaging_default_refresh_period=5,
+        averaging_default_refresh_period=10,
         averaging_expiration=30,
         metadata_expiration=120,
         averaging_timeout=30,
@@ -249,7 +250,8 @@ class Server(threading.Thread):
                     optim,
                     dht=dht,
                     prefix=expert_uid.split(UID_DELIMITER)[0],
-                    compression_type=CompressionType.Value(averaging_compression),
+                    compression=BASE_COMPRESSION_TYPES[averaging_compression],
+                    state_compression=BASE_COMPRESSION_TYPES[averaging_compression],
                     target_batch_size=averaging_target_batch_size,
                     target_group_size=averaging_target_group_size,
                     min_refresh_period=averaging_min_refresh_period,
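The server now configures the averager with compression objects looked up from BASE_COMPRESSION_TYPES rather than a CompressionType protobuf enum value. The snippet below is an assumed usage sketch: indexing the registry by a string mirrors what the diff itself does, but the key "NONE" and the exact set of available names are assumptions to verify against hivemind.compression.

from hivemind.compression import BASE_COMPRESSION_TYPES

# `averaging_compression` arrives as a string from the CLI/Server arguments;
# "NONE" is an assumed valid key here.
averaging_compression = "NONE"
compression = BASE_COMPRESSION_TYPES[averaging_compression]
print(type(compression).__name__)  # the compression strategy instance handed to the averager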

+ 2 - 1
hivemind/optim/performance_ema.py

@@ -1,5 +1,6 @@
 from contextlib import contextmanager
 from threading import Lock
+from typing import Optional
 
 from hivemind.utils import get_dht_time
 
@@ -17,7 +18,7 @@ class PerformanceEMA:
         self.paused = paused
         self.lock = Lock()
 
-    def update(self, task_size: float, interval: float) -> float:
+    def update(self, task_size: float, interval: Optional[float] = None) -> float:
         """
         :param task_size: how many items were processed since last call
         :param interval: optionally provide the time delta it took to process this task
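The signature change makes `interval` optional: callers can either supply the measured time delta or omit it, in which case the tracker is expected to fall back to its own timing. As a hedged sketch, here is a toy re-implementation (structure and defaults are illustrative, not hivemind's PerformanceEMA) of what such a fallback can look like:

import time
from typing import Optional

class ToyEMA:
    """Toy throughput EMA; only the optional-interval behaviour is the point here."""

    def __init__(self, alpha: float = 0.1):
        self.alpha, self.samples_per_second, self.last_ts = alpha, 0.0, time.monotonic()

    def update(self, task_size: float, interval: Optional[float] = None) -> float:
        now = time.monotonic()
        if interval is None:  # no measured delta given: use time since the previous update
            interval = now - self.last_ts
        self.last_ts = now
        rate = task_size / max(interval, 1e-9)
        self.samples_per_second = self.alpha * rate + (1 - self.alpha) * self.samples_per_second
        return self.samples_per_second

ema = ToyEMA()
print(ema.update(task_size=32, interval=0.5))  # caller-measured interval
print(ema.update(task_size=32))                # interval inferred internally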