|
@@ -96,13 +96,14 @@ def centered_clip(
|
|
|
coeffs = weights * torch.minimum(torch.tensor(1.0), tau / norms)
|
|
|
|
|
|
if stop_delta is not None:
|
|
|
- prev_diff = result[...] = diff[0] # Reuse memory from `result`
|
|
|
+ result[...] = diff[0] # Reuse memory from `result`
|
|
|
+ prev_diff = result
|
|
|
|
|
|
# We only need to update `diff` (not `result`) between iterations
|
|
|
diff.addmm_(-coeffs.repeat(n_peers, 1), diff)
|
|
|
|
|
|
if stop_delta is not None:
|
|
|
- delta = prev_diff.sub_(diff[0]).max()
|
|
|
+ delta = prev_diff.sub_(diff[0]).abs().max()
|
|
|
if delta < stop_delta:
|
|
|
break
|
|
|
torch.sub(input_tensors[0], diff[0], out=result)
|