|
@@ -135,10 +135,25 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
|
|
|
]
|
|
|
|
|
|
+ size = 0
|
|
|
+ for t in inputs:
|
|
|
+ size += t.element_size() * t.nelement()
|
|
|
+ if size >= DEFAULT_MAX_MSG_SIZE:
|
|
|
+ deserialized_outputs = cls.forward_partial(serialized_tensors, ctx, stub)
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ deserialized_outputs = cls.forward_oneshot(serialized_tensors, ctx, stub)
|
|
|
+
|
|
|
+ return tuple(deserialized_outputs)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def forward_partial(
|
|
|
+ cls, serialized_tensors: list[runtime_pb2.Tensor], ctx, stub
|
|
|
+ ) -> list[torch.Tensor]:
|
|
|
split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
|
|
|
|
|
|
outputs = cls.run_coroutine(
|
|
|
- stub.rpc_forward(
|
|
|
+ stub.rpc_forward_partial(
|
|
|
amap_in_executor(
|
|
|
lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
|
|
|
as_aiter(*split)
|
|
@@ -146,11 +161,20 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
)
|
|
|
)
|
|
|
|
|
|
- deserialized_outputs = cls.run_coroutine(
|
|
|
+ return cls.run_coroutine(
|
|
|
gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
|
|
|
)
|
|
|
|
|
|
- return tuple(deserialized_outputs)
|
|
|
@classmethod
def forward_oneshot(
    cls, serialized_tensors: list[runtime_pb2.Tensor], ctx, stub
) -> list[torch.Tensor]:
    """Run the remote forward pass with one non-streamed request.

    All serialized tensors are packed into a single ``ExpertRequest`` and
    sent through ``stub.rpc_forward``.

    NOTE(review): presumably reached only when the combined input size is
    below ``DEFAULT_MAX_MSG_SIZE`` -- confirm against the caller.

    Args:
        serialized_tensors: Serialized input tensors, sent together.
        ctx: Autograd context; only ``ctx.uid`` is read here.
        stub: RPC stub exposing ``rpc_forward``.

    Returns:
        The deserialized tensors carried by the response, in order.
    """
    request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
    response = cls.run_coroutine(stub.rpc_forward(request))

    # The payload lives in the response's ``tensors`` field (the same field
    # the backward one-shot path reads); the response message itself is not
    # an iterable of tensors.
    return [deserialize_torch_tensor(t) for t in response.tensors]
|
|
|
|
|
|
@classmethod
|
|
|
@once_differentiable
|
|
@@ -163,10 +187,26 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
|
|
|
]
|
|
|
|
|
|
- split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
|
|
|
+ size = 0
|
|
|
+ for t in inputs_and_grad_outputs:
|
|
|
+ size += t.element_size() * t.nelement()
|
|
|
+ if size >= DEFAULT_MAX_MSG_SIZE:
|
|
|
+ deserialized_grad_inputs = cls.backward_partial(serialized_tensors, ctx)
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ deserialized_grad_inputs = cls.backward_oneshot(serialized_tensors, ctx)
|
|
|
+
|
|
|
+ return (DUMMY, None, None, None, *deserialized_grad_inputs)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ @once_differentiable
|
|
|
+ def backward_partial(
|
|
|
+ cls, serialized_tensors: list[runtime_pb2.Tensor], ctx
|
|
|
+ ) -> list[torch.Tensor]:
|
|
|
+ split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
|
|
|
|
|
|
grad_inputs = cls.run_coroutine(
|
|
|
- ctx.stub.rpc_backward(
|
|
|
+ ctx.stub.rpc_backward_partial(
|
|
|
amap_in_executor(
|
|
|
lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
|
|
|
as_aiter(*split)
|
|
@@ -174,7 +214,17 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
)
|
|
|
)
|
|
|
|
|
|
- deserialized_grad_inputs = cls.run_coroutine(
|
|
|
+ return cls.run_coroutine(
|
|
|
gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
|
|
|
)
|
|
|
- return (DUMMY, None, None, None, *deserialized_grad_inputs)
|
|
|
+
|
|
|
@classmethod
@once_differentiable
def backward_oneshot(
    cls, serialized_tensors: list[runtime_pb2.Tensor], ctx
) -> list[torch.Tensor]:
    """Run the remote backward pass with one non-streamed request.

    Every serialized tensor is placed into a single ``ExpertRequest`` that
    is sent through ``ctx.stub.rpc_backward``; the tensors found in the
    response are deserialized and returned in order.

    Args:
        serialized_tensors: Serialized tensors to send in the request.
        ctx: Autograd context; ``ctx.uid`` and ``ctx.stub`` are read here.

    Returns:
        The deserialized tensors from the response's ``tensors`` field.
    """
    rpc_backward = ctx.stub.rpc_backward
    request = runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)
    response = cls.run_coroutine(rpc_backward(request))

    grad_tensors = []
    for serialized in response.tensors:
        grad_tensors.append(deserialize_torch_tensor(serialized))
    return grad_tensors
|