@@ -5,10 +5,10 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 import hivemind
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.balancer import ExpertBalancer
 from hivemind.moe.client.expert import DUMMY
 from hivemind.proto import runtime_pb2
-from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils import get_logger, nested_compare, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
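Aside (context for the moved import, not part of the patch): `serialize_torch_tensor` and `deserialize_torch_tensor` convert tensors to and from hivemind's `runtime_pb2.Tensor` protobuf messages for transport. A minimal round-trip sketch, assuming the default lossless compression:

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor

tensor = torch.randn(4, 8)
serialized = serialize_torch_tensor(tensor)  # runtime_pb2.Tensor message; no compression by default
restored = deserialize_torch_tensor(serialized)
assert torch.allclose(tensor, restored)  # lossless round trip
```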
@@ -66,14 +66,16 @@ class BalancedRemoteExpert(nn.Module):
         forward_task_size = flat_inputs[0].shape[0]
 
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
-        flat_outputs = _BalancedRemoteModuleCall.apply(DUMMY,
-                                                       self.expert_balancer,
-                                                       self.info,
-                                                       self.forward_timeout,
-                                                       self.backward_timeout,
-                                                       forward_task_size,
-                                                       forward_task_size * self.backward_task_size_multiplier,
-                                                       *flat_inputs)
+        flat_outputs = _BalancedRemoteModuleCall.apply(
+            DUMMY,
+            self.expert_balancer,
+            self.info,
+            self.forward_timeout,
+            self.backward_timeout,
+            forward_task_size,
+            forward_task_size * self.backward_task_size_multiplier,
+            *flat_inputs,
+        )
 
         return nested_pack(flat_outputs, structure=self.info["outputs_schema"])
 
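The DUMMY trick referenced in the comment above: `torch.autograd.Function.apply` records a node in the autograd graph only if at least one tensor argument requires grad, so a zero-element placeholder with `requires_grad=True` guarantees that `backward` fires even when the real inputs are plain data. A self-contained sketch of the pattern, where the toy `_RemoteCallSketch` stands in for `_BalancedRemoteModuleCall` (in hivemind, `DUMMY` comes from `hivemind.moe.client.expert`):

```python
import torch

DUMMY = torch.empty(0, requires_grad=True)  # zero-element placeholder, mirrors hivemind's DUMMY

class _RemoteCallSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # stand-in for the remote expert's forward pass

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # one return slot per forward argument: a zero grad for dummy, chain rule for x
        return torch.zeros(0), grad_output * 2

x = torch.randn(3)  # requires_grad=False, like a raw input batch
out = _RemoteCallSketch.apply(DUMMY, x)
assert out.requires_grad  # DUMMY kept the call in the autograd graph
out.sum().backward()      # backward still runs, even though x carries no grad
```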
@@ -93,16 +95,16 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
 
     @staticmethod
     def forward(
-            ctx,
-            dummy: torch.Tensor,
-            expert_balancer: ExpertBalancer,
-            info: Dict[str, Any],
-            forward_timeout: float,
-            backward_timeout: float,
-            forward_task_size: float,
-            backward_task_size: float,
-            *inputs: torch.Tensor,
-            ) -> Tuple[torch.Tensor, ...]:
+        ctx,
+        dummy: torch.Tensor,
+        expert_balancer: ExpertBalancer,
+        info: Dict[str, Any],
+        forward_timeout: float,
+        backward_timeout: float,
+        forward_task_size: float,
+        backward_task_size: float,
+        *inputs: torch.Tensor,
+    ) -> Tuple[torch.Tensor, ...]:
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
         # detach to avoid pickling the computation graph
         ctx.expert_balancer, ctx.info = expert_balancer, info
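The signature above mixes tensors with non-tensor arguments: `Function.apply` passes everything through positionally, only tensor arguments participate in autograd, and `backward` must return one value per `forward` argument. A hypothetical reduction of the pattern (the names `info` and `timeout` are chosen for illustration), including the detach-before-serialization step the trailing comment describes:

```python
from typing import Any, Dict, Tuple
import torch

class _ForwardSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy: torch.Tensor, info: Dict[str, Any], timeout: float,
                *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        ctx.info, ctx.timeout = info, timeout        # stash non-tensor state for backward
        detached = tuple(t.detach() for t in inputs) # drop autograd history before serialization
        return detached                              # pretend these came back from the remote expert

    @staticmethod
    def backward(ctx, *grad_outputs: torch.Tensor):
        # one slot per forward argument: dummy, info, timeout, *inputs
        # (None for the non-tensor arguments; identity forward means grads pass through)
        return (torch.zeros(0), None, None, *grad_outputs)
```

In the real `_BalancedRemoteModuleCall`, the detached inputs are presumably serialized with `serialize_torch_tensor` and dispatched to an expert chosen by `expert_balancer`, which is why the computation graph must not be pickled along with them.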