5
0
justheuristic 2 жил өмнө
parent
commit
8b8d54abc5

+ 1 - 1
src/client/remote_forward_backward.py

@@ -111,7 +111,7 @@ async def run_remote_backward(
     inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
-    metadata: bytes = b"",
+    **kwargs,
 ) -> Sequence[torch.Tensor]:
     """
     Serializes grad outputs and calls "rpc_backward" on a remote server.