@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable

-import hivemind
+from hivemind import moe
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.p2p import P2P, PeerInfo, StubBase
@@ -32,8 +32,8 @@ from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


-def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
-    return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
+def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandlerStub":
+    return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)


 @dataclass(frozen=True)
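
Note on the new return annotation: a quoted annotation such as `"ConnectionHandlerStub"` is stored as a plain string and never evaluated at runtime, so the function gains a usable type hint without the eager import that the old `# -> ConnectionHandlerStub:` comment was avoiding. A minimal sketch of the usual `typing.TYPE_CHECKING` companion pattern; the import path, the helper name, and the assumption that the stub class is importable at all are illustrative, not taken from this diff:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by type checkers only; skipped at runtime, so it cannot
        # introduce a circular import. The path is an assumption for illustration.
        from hivemind.moe.server.connection_handler import ConnectionHandlerStub


    def get_stub_sketch(p2p, peer_info) -> "ConnectionHandlerStub":
        # The quoted annotation stays a string in __annotations__;
        # no name lookup happens when this module is imported.
        ...
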
@@ -251,13 +251,12 @@ async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions:
 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""

-    @classmethod
+    @staticmethod
     def forward(
-        cls,
         ctx,
         dummy: torch.Tensor,
         uid: str,
-        stub,  #: ConnectionHandlerStub,
+        stub: "ConnectionHandlerStub",
         info: Dict[str, Any],
         *inputs: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:
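
The `dummy: torch.Tensor` argument above is the `DUMMY` tensor defined earlier: a zero-element tensor with `requires_grad=True` whose only job is to force autograd to record the remote call even when the real inputs arrive with `requires_grad=False`. A self-contained toy (hypothetical, not part of hivemind) showing the effect:

    import torch

    DUMMY = torch.empty(0, requires_grad=True)  # same trick as in the diff

    class _TripleWithDummy(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dummy: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
            return x * 3

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            # One gradient per forward input: None for the dummy, d(3x)/dx = 3 for x.
            return None, grad_output * 3

    x = torch.ones(2)  # requires_grad=False, e.g. activations from a frozen layer
    out = _TripleWithDummy.apply(DUMMY, x)
    print(out.requires_grad)  # True: the dummy keeps the call on the autograd tape
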
@@ -273,9 +272,9 @@ class _RemoteModuleCall(torch.autograd.Function):

         return tuple(deserialized_outputs)

-    @classmethod
+    @staticmethod
     @once_differentiable
-    def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
+    def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
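
For reference, `torch.autograd.Function` documents `forward` and `backward` as `@staticmethod`s that receive `ctx` as their first argument, which is what both hunks above switch to; `@once_differentiable` additionally guards against anyone trying to differentiate through `backward` itself. A self-contained toy Function (hypothetical, not part of hivemind) with the same shape:

    import torch
    from torch.autograd.function import once_differentiable

    class _ScaleByTwo(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp: torch.Tensor) -> torch.Tensor:
            ctx.save_for_backward(inp)  # analogous to the ctx.saved_tensors usage above
            return inp * 2

        @staticmethod
        @once_differentiable  # backward runs without grad tracking; double-backward raises
        def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
            (inp,) = ctx.saved_tensors
            return grad_output * 2  # d(2x)/dx = 2

    x = torch.randn(3, requires_grad=True)
    _ScaleByTwo.apply(x).sum().backward()
    assert torch.allclose(x.grad, torch.full_like(x, 2.0))
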