5
0
Your Name 1 жил өмнө
parent
commit
9e29140bb0

+ 2 - 2
src/petals/client/remote_forward_backward.py

@@ -144,8 +144,8 @@ async def run_remote_backward(
             for tensor, compression in zip(flat_tensors, codecs)
         )
     )
-    for tensor, serialized_tensor in zip(flat_tensors, serialized_tensors):
-        serialized_tensor.requires_grad = tensor.requires_grad
+    for tensor, serialized in zip(flat_tensors, serialized_tensors):
+        serialized.requires_grad = tensor.requires_grad  # see https://github.com/learning-at-home/hivemind/pull/591
 
     size = sum(t.element_size() * t.nelement() for t in flat_tensors)
     backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary