
implement partial and oneshot backward/forward

Pavel Samygin 3 years ago
parent
commit
c782cdf53b
2 changed files with 79 additions and 7 deletions
  1. hivemind/moe/client/expert.py (+57 −7)
  2. hivemind/moe/server/connection_handler.py (+22 −0)
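
The gist of the change: both `forward` and `backward` now measure the total byte size of their input tensors and only fall back to the chunked streaming RPCs (`rpc_forward_partial` / `rpc_backward_partial`) when the payload would not fit into a single message. A minimal sketch of that size check, assuming an illustrative `MAX_MSG_SIZE` constant in place of hivemind's `DEFAULT_MAX_MSG_SIZE`:

    import torch

    MAX_MSG_SIZE = 2 ** 21  # illustrative stand-in for DEFAULT_MAX_MSG_SIZE

    def needs_streaming(tensors, limit: int = MAX_MSG_SIZE) -> bool:
        """Return True once the summed tensor payload reaches the unary RPC limit."""
        size = 0
        for t in tensors:
            size += t.element_size() * t.nelement()  # bytes held by this tensor
            if size >= limit:
                return True  # too big for one message: take the streaming path
        return False  # fits: a single one-shot request suffices

    # e.g. needs_streaming([torch.zeros(1024, 1024)]) -> True (4 MiB of float32)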

+ 57 - 7
hivemind/moe/client/expert.py

@@ -135,10 +135,25 @@ class _RemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
 
+        size = 0
+        for t in inputs:
+            size += t.element_size() * t.nelement()
+            if size >= DEFAULT_MAX_MSG_SIZE:
+                deserialized_outputs = cls.forward_partial(serialized_tensors, ctx, stub)
+                break
+        else:
+            deserialized_outputs = cls.forward_oneshot(serialized_tensors, ctx, stub)
+
+        return tuple(deserialized_outputs)
+
+    @classmethod
+    def forward_partial(
+        cls, serialized_tensors: list[runtime_pb2.Tensor], ctx, stub
+    ) -> list[torch.Tensor]:
         split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
 
         outputs = cls.run_coroutine(
-            stub.rpc_forward(
+            stub.rpc_forward_partial(
                 amap_in_executor(
                     lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
                     as_aiter(*split)
@@ -146,11 +161,20 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        deserialized_outputs = cls.run_coroutine(
+        return cls.run_coroutine(
             gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
         )
 
-        return tuple(deserialized_outputs)
+    @classmethod
+    def forward_oneshot(
+        cls, serialized_tensors: list[runtime_pb2.Tensor], ctx, stub
+    ) -> list[torch.Tensor]:
+
+        outputs = cls.run_coroutine(
+            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        )
+
+        return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
     @classmethod
     @once_differentiable
@@ -163,10 +187,26 @@ class _RemoteModuleCall(torch.autograd.Function):
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
         ]
 
-        split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
+        size = 0
+        for t in inputs_and_grad_outputs:
+            size += t.element_size() * t.nelement()
+            if size >= DEFAULT_MAX_MSG_SIZE:
+                deserialized_grad_inputs = cls.backward_partial(serialized_tensors, ctx)
+                break
+        else:
+            deserialized_grad_inputs = cls.backward_oneshot(serialized_tensors, ctx)
+
+        return (DUMMY, None, None, None, *deserialized_grad_inputs)
+
+    @classmethod
+    def backward_partial(
+        cls, serialized_tensors: list[runtime_pb2.Tensor], ctx
+    ) -> list[torch.Tensor]:
+        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
         grad_inputs = cls.run_coroutine(
-            ctx.stub.rpc_backward(
+            ctx.stub.rpc_backward_partial(
                 amap_in_executor(
                     lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
                     as_aiter(*split)
@@ -174,7 +214,17 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        deserialized_grad_inputs = cls.run_coroutine(
+        return cls.run_coroutine(
             gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
         )
-        return (DUMMY, None, None, None, *deserialized_grad_inputs)
+
+    @classmethod
+    def backward_oneshot(
+        cls, serialized_tensors: list[runtime_pb2.Tensor], ctx
+    ) -> list[torch.Tensor]:
+        grad_inputs = cls.run_coroutine(
+            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+        )
+
+        return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
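
A note on the dispatch above: the `for ... else` in `forward` and `backward` is easy to misread. The `else` body runs only when the loop finishes without hitting `break`, i.e. when the accumulated size never reached `DEFAULT_MAX_MSG_SIZE`, so the one-shot path is taken. A tiny self-contained illustration (the names are ours, not hivemind's):

    def pick_path(tensor_sizes, limit):
        total = 0
        for size in tensor_sizes:
            total += size
            if total >= limit:
                path = "partial (streamed)"  # break skips the else clause
                break
        else:
            path = "oneshot (unary)"  # loop completed: payload fits in one message
        return path

    assert pick_path([4, 8, 16], limit=1000) == "oneshot (unary)"
    assert pick_path([4, 8, 16], limit=16) == "partial (streamed)"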

+ 22 - 0
hivemind/moe/server/connection_handler.py

@@ -90,6 +90,17 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         ]
 
     async def rpc_forward(
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
+        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        expert = self.experts[request.uid]
+        return runtime_pb2.ExpertResponse(
+            tensors=await self._process_inputs(
+                inputs, expert.forward_pool, expert.outputs_schema
+            )
+        )
+
+    async def rpc_forward_partial(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         uid, inputs = await self._gather_inputs(requests, context)
@@ -103,6 +114,17 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             yield runtime_pb2.ExpertResponse(tensors=[part, ])
 
     async def rpc_backward(
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
+        inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        expert = self.experts[request.uid]
+        return runtime_pb2.ExpertResponse(
+            tensors=await self._process_inputs(
+                inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema
+            )
+        )
+
+    async def rpc_backward_partial(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         uid, inputs_and_grads = await self._gather_inputs(requests, context)
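
The streamed handlers follow a split/reassemble protocol: the client cuts each serialized tensor into parts no larger than `DEFAULT_MAX_MSG_SIZE // 2`, the server gathers them back with `_gather_inputs`, and the response is streamed out the same way via `split_for_streaming`. A byte-string sketch of that round trip (hivemind operates on serialized `runtime_pb2.Tensor` parts, not raw bytes; this is only an illustration):

    MAX_PART = 4  # illustrative stand-in for DEFAULT_MAX_MSG_SIZE // 2

    def split_payload(payload: bytes, max_part: int = MAX_PART):
        # client side: yield consecutive chunks no larger than max_part
        for start in range(0, len(payload), max_part):
            yield payload[start:start + max_part]

    def gather_payload(parts) -> bytes:
        # server side: reassemble the streamed chunks in arrival order
        return b"".join(parts)

    assert gather_payload(split_payload(b"0123456789")) == b"0123456789"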