Selaa lähdekoodia

Fix bugs in centered_clip()

Alexander Borzunov 3 vuotta sitten
vanhempi
commit
e9f3288f50
1 muutettua tiedostoa jossa 3 lisäystä ja 2 poistoa
  1. 3 2
      hivemind/averaging/accumulators.py

+ 3 - 2
hivemind/averaging/accumulators.py

@@ -96,13 +96,14 @@ def centered_clip(
             coeffs = weights * torch.minimum(torch.tensor(1.0), tau / norms)
 
             if stop_delta is not None:
-                prev_diff = result[...] = diff[0]  # Reuse memory from `result`
+                result[...] = diff[0]  # Reuse memory from `result`
+                prev_diff = result
 
             # We only need to update `diff` (not `result`) between iterations
             diff.addmm_(-coeffs.repeat(n_peers, 1), diff)
 
             if stop_delta is not None:
-                delta = prev_diff.sub_(diff[0]).max()
+                delta = prev_diff.sub_(diff[0]).abs().max()
                 if delta < stop_delta:
                     break
         torch.sub(input_tensors[0], diff[0], out=result)