|
@@ -17,7 +17,7 @@ from src.data_structures import ModuleUID, RPCInfo
|
|
|
|
|
|
|
|
|
async def run_remote_forward(
|
|
|
- uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b'', **kwargs
|
|
|
+ uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
|
|
|
) -> Tuple[torch.Tensor, ...]:
|
|
|
"""
|
|
|
Serializes input tensors and calls "rpc_forward" on a remote server.
|
|
@@ -89,7 +89,6 @@ async def _forward_unary(
|
|
|
return [deserialize_torch_tensor(t) for t in outputs.tensors]
|
|
|
|
|
|
|
|
|
-
|
|
|
async def _backward_stream(
|
|
|
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
|
|
|
) -> List[torch.Tensor]:
|
|
@@ -112,7 +111,7 @@ async def run_remote_backward(
|
|
|
inputs: torch.Tensor,
|
|
|
grad_outputs: List[torch.Tensor],
|
|
|
*extra_tensors: torch.Tensor,
|
|
|
- metadata: bytes = b''
|
|
|
+ metadata: bytes = b"",
|
|
|
) -> Sequence[torch.Tensor]:
|
|
|
"""
|
|
|
Serializes grad outputs and calls "rpc_backward" on a remote server.
|