Pārlūkot izejas kodu

cover edge case

justheuristic 2 gadi atpakaļ
vecāks
revīzija
b8c88e30f2
1 mainītis faili ar 2 papildinājumiem un 3 dzēšanām
  1. 2 3
      src/client/remote_forward_backward.py

+ 2 - 3
src/client/remote_forward_backward.py

@@ -17,7 +17,7 @@ from src.data_structures import ModuleUID, RPCInfo
 
 
 async def run_remote_forward(
-    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b'', **kwargs
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.
@@ -89,7 +89,6 @@ async def _forward_unary(
     return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
 
-
 async def _backward_stream(
     uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
 ) -> List[torch.Tensor]:
@@ -112,7 +111,7 @@ async def run_remote_backward(
     inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
-    metadata: bytes = b''
+    metadata: bytes = b"",
 ) -> Sequence[torch.Tensor]:
     """
     Serializes grad outputs and calls "rpc_backward" on a remote server.