|
@@ -240,17 +240,17 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
|
|
|
outputs_schema = requested_backends[-1].outputs_schema
|
|
|
|
|
|
- if metadata.get("output_compressions") is not None:
|
|
|
- assert isinstance(metadata["output_compressions"], (list, tuple)), "output_compression must be a tuple/list"
|
|
|
- output_compressions = tuple(metadata["output_compressions"])
|
|
|
- assert all(isinstance(c, int) for c in output_compressions), "output_compression must contain integers"
|
|
|
- assert len(output_compressions) == 1, f"output_compression tuple should have 1 element"
|
|
|
+ if metadata.get("output_compression") is not None:
|
|
|
+ assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
|
|
|
+ output_compression = tuple(metadata["output_compression"])
|
|
|
+ assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
|
|
|
+ assert len(output_compression) == 1, f"output_compression tuple should have 1 element"
|
|
|
else:
|
|
|
- output_compressions = tuple(tensor.compression for tensor in outputs_schema)
|
|
|
+ output_compression = tuple(tensor.compression for tensor in outputs_schema)
|
|
|
|
|
|
return [
|
|
|
serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
|
|
|
- for result, proto, compression in zip([hidden_states], outputs_schema, output_compressions)
|
|
|
+ for result, proto, compression in zip([hidden_states], outputs_schema, output_compression)
|
|
|
]
|
|
|
|
|
|
async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
|
@@ -308,17 +308,17 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
|
|
|
) # TODO generalize
|
|
|
|
|
|
- if metadata.get("output_compressions") is not None:
|
|
|
- assert isinstance(metadata["output_compressions"], (list, tuple)), "output_compression must be a tuple/list"
|
|
|
- output_compressions = tuple(metadata["output_compressions"])
|
|
|
- assert all(isinstance(c, int) for c in output_compressions), "output_compression must contain integers"
|
|
|
- assert len(output_compressions) == len(grads), f"output_compression should have {len(grads)} elements"
|
|
|
+ if metadata.get("output_compression") is not None:
|
|
|
+ assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
|
|
|
+ output_compression = tuple(metadata["output_compression"])
|
|
|
+ assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
|
|
|
+ assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements"
|
|
|
else:
|
|
|
- output_compressions = tuple(tensor.compression for tensor in flat_grads_schema)
|
|
|
+ output_compression = tuple(tensor.compression for tensor in flat_grads_schema)
|
|
|
|
|
|
return [
|
|
|
serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
|
|
|
- for result, proto, compression in zip(grads, flat_grads_schema, output_compressions)
|
|
|
+ for result, proto, compression in zip(grads, flat_grads_schema, output_compression)
|
|
|
]
|
|
|
|
|
|
def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:
|