瀏覽代碼

Fix bugs in centered_clip()

Alexander Borzunov 3 年之前
父節點
當前提交
e9f3288f50
共有 1 個文件被更改,包括 3 次插入2 次删除
  1. 3 2
      hivemind/averaging/accumulators.py

+ 3 - 2
hivemind/averaging/accumulators.py

@@ -96,13 +96,14 @@ def centered_clip(
             coeffs = weights * torch.minimum(torch.tensor(1.0), tau / norms)
 
             if stop_delta is not None:
-                prev_diff = result[...] = diff[0]  # Reuse memory from `result`
+                result[...] = diff[0]  # Reuse memory from `result`
+                prev_diff = result
 
             # We only need to update `diff` (not `result`) between iterations
             diff.addmm_(-coeffs.repeat(n_peers, 1), diff)
 
             if stop_delta is not None:
-                delta = prev_diff.sub_(diff[0]).max()
+                delta = prev_diff.sub_(diff[0]).abs().max()
                 if delta < stop_delta:
                     break
         torch.sub(input_tensors[0], diff[0], out=result)