@@ -81,6 +81,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs to backend dtype
+        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
         # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
@@ -93,7 +96,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         return runtime_pb2.ExpertResponse(
             tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                 for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
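
A minimal sketch (outside the diff) of the dtype round-trip that the two hunks above implement: incoming hidden states are cast to the backend's compute dtype before running the block, and the result is cast back to the dtype declared by the output schema before serialization. The names backend_dtype, schema_dtype, and the toy Linear block are stand-ins for requested_backends[0].dtype, proto.dtype, and a real backend module; this illustrates the casting pattern, not the handler's actual code.

import torch

# Assumed setup: the server runs its block in bfloat16, while the wire schema
# (the proto.dtype taken from outputs_schema) declares float32.
backend_dtype = torch.bfloat16   # stand-in for requested_backends[0].dtype
schema_dtype = torch.float32     # stand-in for proto.dtype
block = torch.nn.Linear(16, 16).to(backend_dtype)  # toy stand-in for one backend module

# The client sends float32 hidden states; cast them to the backend dtype before compute.
hidden_states = torch.randn(1, 4, 16, dtype=torch.float32)
hidden_states = hidden_states.to(backend_dtype)

with torch.no_grad():
    result = block(hidden_states)  # computed in bfloat16

# Cast back to the schema dtype before serializing, so the payload matches what
# the compression codec and the client expect.
payload = result.to(schema_dtype)
assert payload.dtype == schema_dtype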
@@ -106,6 +109,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs to backend dtype
+        hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
         # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
@@ -117,7 +123,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Serialize the overall output
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
             for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
         ]
 
@@ -134,6 +140,10 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs & grad outputs to backend dtype
+        inputs = inputs.to(requested_backends[0].dtype)
+        grads = grads.to(requested_backends[-1].dtype)
+
         # Run a forward chain to collect intermediate inputs
         # Note that we do not forward for the last module since we do not need its output
         inter_inputs = [inputs]
@@ -154,7 +164,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
-                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                 for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
             ]
         )
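
The backward-path hunks follow the same pattern with one extra cast: the inputs are aligned with the first backend's dtype, the incoming output gradients with the last backend's dtype, and the resulting input gradients are cast back to the grad_inputs_schema dtype before serialization. A hedged sketch with plain torch, using toy Linear blocks and assumed dtypes in place of the real backends and schemas:

import torch

first_dtype = torch.bfloat16        # stand-in for requested_backends[0].dtype
last_dtype = torch.bfloat16         # stand-in for requested_backends[-1].dtype
grad_schema_dtype = torch.float32   # stand-in for proto.dtype from grad_inputs_schema

blocks = [torch.nn.Linear(16, 16).to(first_dtype) for _ in range(2)]  # toy backend chain

inputs = torch.randn(1, 4, 16)  # arrives from the wire as float32
grads = torch.randn(1, 4, 16)   # grad w.r.t. the chain's output, also float32

# The two casts added by the diff: align inputs and grad outputs with the compute dtype.
inputs = inputs.to(first_dtype).requires_grad_(True)
grads = grads.to(last_dtype)

# Re-run the forward chain, then backprop the provided output grads through it.
out = inputs
for block in blocks:
    out = block(out)
out.backward(grads)

# Cast the grad w.r.t. the inputs back to the schema dtype before serialization.
grad_inputs = inputs.grad.to(grad_schema_dtype)
assert grad_inputs.dtype == grad_schema_dtype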
@@ -162,11 +172,16 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+
         uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
         inputs, grads = inputs_and_grads
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
+        # Cast inputs & grad outputs to backend dtype
+        inputs = inputs.to(requested_backends[0].dtype)
+        grads = grads.to(requested_backends[-1].dtype)
+
         # Run a forward chain to collect intermediate inputs
         # Note that we do not forward for the last module since we do not need its outputs
         inter_inputs = [inputs]
@@ -186,7 +201,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
             for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
         ]
         # Split the serialized_grad_inputs for streaming and respond
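
One side note on cost (plain torch behavior, not specific to this handler): Tensor.to(dtype) returns the original tensor unchanged when the dtype already matches, so the casts added in this diff are effectively free whenever the wire dtype and the backend dtype agree.

import torch

x = torch.randn(2, 3, dtype=torch.float32)
assert x.to(torch.float32) is x       # no copy when the dtype already matches
assert x.to(torch.bfloat16) is not x  # a real cast allocates a new tensor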