Commit 5ae492f076 — "Post-merge fixes"
Author: Max Ryabinin (4 years ago)

Changed files:

+ 1 - 0
hivemind/__init__.py
hivemind/__init__.py

@@ -2,6 +2,7 @@ from hivemind.averaging import DecentralizedAverager, TrainingAverager
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
+    BalancedRemoteExpert,
     ExpertBackend,
     RemoteExpert,
     RemoteMixtureOfExperts,

+ 6 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,7 @@
-from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.client import (
+    BalancedRemoteExpert,
+    RemoteExpert,
+    RemoteMixtureOfExperts,
+    RemoteSwitchMixtureOfExperts,
+)
 from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class

+ 1 - 0
hivemind/moe/client/__init__.py

@@ -1,3 +1,4 @@
+from hivemind.moe.client.balanced_expert import BalancedRemoteExpert
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts
 from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts

+ 21 - 19
hivemind/moe/client/balanced_expert.py

@@ -5,10 +5,10 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.balancer import ExpertBalancer
 from hivemind.moe.client.expert import DUMMY
 from hivemind.proto import runtime_pb2
-from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
@@ -66,14 +66,16 @@ class BalancedRemoteExpert(nn.Module):
         forward_task_size = flat_inputs[0].shape[0]
 
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
-        flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
-                                                       self.expert_balancer,
-                                                       self.info,
-                                                       self.forward_timeout,
-                                                       self.backward_timeout,
-                                                       forward_task_size,
-                                                       forward_task_size * self.backward_task_size_multiplier,
-                                                       *flat_inputs)
+        flat_outputs = _BalancedRemoteModuleCall.apply(
+            DUMMY,
+            self.expert_balancer,
+            self.info,
+            self.forward_timeout,
+            self.backward_timeout,
+            forward_task_size,
+            forward_task_size * self.backward_task_size_multiplier,
+            *flat_inputs,
+        )
 
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
 
@@ -93,16 +95,16 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
 
     @staticmethod
     def forward(
-            ctx,
-            dummy: torch.Tensor,
-            expert_balancer: ExpertBalancer,
-            info: Dict[str, Any],
-            forward_timeout: float,
-            backward_timeout: float,
-            forward_task_size: float,
-            backward_task_size: float,
-            *inputs: torch.Tensor,
-            ) -> Tuple[torch.Tensor, ...]:
+        ctx,
+        dummy: torch.Tensor,
+        expert_balancer: ExpertBalancer,
+        info: Dict[str, Any],
+        forward_timeout: float,
+        backward_timeout: float,
+        forward_task_size: float,
+        backward_task_size: float,
+        *inputs: torch.Tensor,
+    ) -> Tuple[torch.Tensor, ...]:
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
         # detach to avoid pickling the computation graph
         ctx.expert_balancer, ctx.info = expert_balancer, info

+ 7 - 4
hivemind/moe/client/balancer.py

@@ -14,8 +14,9 @@ logger = get_logger(__name__)
 
 
 class ExpertBalancer:
-    def __init__(self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
-                 **kwargs):
+    def __init__(
+        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0, **kwargs
+    ):
         self.dht, self.key = dht, key
         self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
         self.experts = TimedStorage[ExpertUID, Endpoint]()
@@ -55,8 +56,10 @@ class ExpertBalancer:
                     except Exception as e:
                         logger.warning(f"Skipping malformed expert info {expert_info} (exc={e})")
             else:
-                logger.warning(f"Could not refresh experts, dht info key contains {response}, "
-                               f"will retry in {time_to_next_update}s")
+                logger.warning(
+                    f"Could not refresh experts, dht info key contains {response}, "
+                    f"will retry in {time_to_next_update}s"
+                )
             if len(self.queue) == 0:
                 logger.warning("Update routine finished, but still no experts available.")
 

+ 3 - 1
hivemind/optim/performance_ema.py

@@ -28,7 +28,9 @@ class PerformanceEMA:
         if not self.paused:
             self.timestamp, old_timestamp = get_dht_time(), self.timestamp
             interval = interval if interval is not None else max(0.0, self.timestamp - old_timestamp)
-        self.ema_seconds_per_sample = self.alpha * interval / task_size + (1 - self.alpha) * self.ema_seconds_per_sample
+        self.ema_seconds_per_sample = (
+            self.alpha * interval / task_size + (1 - self.alpha) * self.ema_seconds_per_sample
+        )
         self.num_updates += 1
         adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates)
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)

+ 1 - 2
tests/test_util_modules.py

@@ -11,11 +11,11 @@ import torch
 
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
-from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.utils.asyncio import (
     achain,
     aenumerate,
@@ -534,7 +534,6 @@ def test_performance_ema_threadsafe(
     bias_power: float = 0.7,
     tolerance: float = 0.05,
 ):
-
     def run_task(ema):
         task_size = random.randint(1, 4)
         with ema.update_threadsafe(task_size):