@@ -180,7 +180,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         prioritizer=self._prioritizer,
                         points=points,
                         quant_type=self.quant_type,
-                        structure=args_structure,
+                        args_structure=args_structure,
                     ):
                         if can_push:
                             task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
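For reviewers, a quick illustration of what the renamed `args_structure` kwarg carries: it is the nesting template that pairs with the flat tensor list, so nested `(args, kwargs)` can be rebuilt on the other side. A minimal sketch using hivemind's nested utilities (the example tensors are assumed, not taken from the patch):

```python
# Illustrative sketch only: `args_structure` plays the role of the nesting
# template that turns a flat tensor list back into (args, kwargs).
import torch
from hivemind.utils import nested_flatten, nested_pack

args_kwargs = ((torch.zeros(2, 8),), {"attention_mask": torch.ones(2, 8)})
flat_tensors = list(nested_flatten(args_kwargs))  # flat list that travels over the wire
restored = nested_pack(flat_tensors, args_kwargs)  # template restores the nesting
assert restored[1]["attention_mask"].shape == (2, 8)
```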
@@ -444,16 +444,18 @@ class TransformerConnectionHandler(ConnectionHandler):
             points, (float, int)
         ), f"rpc_backward should have number of points as number or None, got {points}"

-        grads = await run_rpc_backward(
+        flat_grads, grads_structure = await run_rpc_backward(
             *flat_tensors,
             requested_backends=requested_backends,
             prioritizer=self._prioritizer,
             active_adapter=active_adapter,
             points=points,
-            structure=args_structure,
+            args_structure=args_structure,
         )

-        return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
+        serialized_flat_grads = self._serialize_grads(flat_grads, flat_tensors, metadata)
+        serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grads_structure))
+        return runtime_pb2.ExpertResponse(tensors=serialized_flat_grads, metadata=serialized_output_metadata)
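With this change, the non-streaming backward response carries the gradient structure in its metadata. A hypothetical receiving-side helper (name and flow assumed for illustration) could recover both pieces like so:

```python
from hivemind.compression import deserialize_torch_tensor
from hivemind.proto import runtime_pb2
from hivemind.utils.serializer import MSGPackSerializer

def unpack_backward_response(response: runtime_pb2.ExpertResponse):
    """Hypothetical helper: recover flat gradients and the structure echoed by the server."""
    meta = MSGPackSerializer.loads(response.metadata) if response.metadata else {}
    flat_grads = [deserialize_torch_tensor(tensor) for tensor in response.tensors]
    return flat_grads, meta.get("structure")
```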

     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -471,18 +473,20 @@ class TransformerConnectionHandler(ConnectionHandler):
             points, (float, int)
         ), f"rpc_backward_stream should have number of points as number or None, got {points}"

-        grads = await run_rpc_backward(
+        flat_grads, grads_structure = await run_rpc_backward(
             *flat_tensors,
             requested_backends=requested_backends,
             prioritizer=self._prioritizer,
             active_adapter=active_adapter,
             points=points,
-            structure=args_structure,
+            args_structure=args_structure,
         )
         # Split the serialized_grad_inputs for streaming and respond
-        for tensor in self._serialize_grads(grads, requested_backends, metadata):
+        serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grads_structure))
+        for tensor in self._serialize_grads(flat_grads, flat_tensors, metadata):
             for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
-                yield runtime_pb2.ExpertResponse(tensors=[part])
+                yield runtime_pb2.ExpertResponse(tensors=[part], metadata=serialized_output_metadata)
+                serialized_output_metadata = None  # attach metadata to the first response only
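Because the metadata rides only on the first streamed message, a consumer has to capture it once and then keep collecting tensor parts. A hedged sketch of that loop (helper name assumed):

```python
from typing import AsyncIterator
from hivemind.proto import runtime_pb2
from hivemind.utils.serializer import MSGPackSerializer

async def collect_backward_stream(responses: AsyncIterator[runtime_pb2.ExpertResponse]):
    """Hypothetical consumer: metadata arrives on the first message only."""
    output_metadata, parts = None, []
    async for response in responses:
        if output_metadata is None and response.metadata:
            output_metadata = MSGPackSerializer.loads(response.metadata)
        parts.extend(response.tensors)  # parts are later re-combined into full tensors
    return parts, output_metadata
```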

     def _get_active_adapter(self, metadata: dict) -> str:
         active_adapter = metadata.get("active_adapter", "")
@@ -492,28 +496,26 @@ class TransformerConnectionHandler(ConnectionHandler):

     def _serialize_grads(
         self,
-        grads: Sequence[torch.Tensor],
-        requested_backends: Sequence[TransformerBackend],
-        metadata: Dict[str, Any],
+        flat_grads: Sequence[torch.Tensor],
+        flat_inputs: Sequence[torch.Tensor],
+        input_metadata: Dict[str, Any],
     ) -> Sequence[runtime_pb2.Tensor]:
         """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
-        # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-        flat_grads_schema = tuple(
-            nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
-        )  # TODO generalize
-
-        if metadata.get("output_compression") is not None:
-            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
-            output_compression = tuple(metadata["output_compression"])
+        inputs_with_grad = tuple(input for input in flat_inputs if input.requires_grad)
+        assert len(flat_grads) == len(inputs_with_grad), f"user provided {len(inputs_with_grad)} inputs with grad, " \
+            f"but backward produced {len(flat_grads)} gradients"
+        if input_metadata.get("output_compression") is not None:
+            output_compression = input_metadata["output_compression"]
+            assert isinstance(output_compression, (list, tuple)), "output_compression must be a tuple/list"
             assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
-            assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements"
+            assert len(output_compression) == len(flat_grads), f"output_compression should have {len(flat_grads)} " \
+                f"elements, one for every tensor that requires grad"
         else:
-            output_compression = tuple(tensor.compression for tensor in flat_grads_schema)
-
+            output_compression = tuple(runtime_pb2.NONE for _ in flat_grads)
+        output_compression = tuple(output_compression)
         return [
-            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
-            for result, proto, compression in zip(grads, flat_grads_schema, output_compression)
+            serialize_torch_tensor(result.to(input.dtype), compression, allow_inplace=True)
+            for result, input, compression in zip(flat_grads, inputs_with_grad, output_compression)
         ]
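For completeness, a caller-side illustration of the `output_compression` contract enforced above: one integer compression code per input that requires grad. The tensor count and the FLOAT16 choice are assumed for the example:

```python
from hivemind.proto import runtime_pb2
from hivemind.utils.serializer import MSGPackSerializer

# Suppose two of the forwarded tensors require grad (assumed for illustration):
n_inputs_with_grad = 2
request_metadata = MSGPackSerializer.dumps(
    dict(output_compression=[runtime_pb2.FLOAT16] * n_inputs_with_grad)
)
# The server checks that every element is an int and that the list length
# matches the number of gradients produced by backward.
```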

     def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]: