
write an essay on why that matters.

justheuristic 3 years ago
parent
commit
38d1a60efe
1 file changed with 5 additions and 0 deletions

hivemind/optim/grad_scaler.py  (+5, -0)

@@ -54,6 +54,11 @@ class GradScaler(TorchGradScaler):
             if self._is_running_global_step:
                 super().unscale_(optimizer)
                 self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                # note: we store the unscaled optimizer state in a separate dict rather than in _per_optimizer_states
+                # to avoid an edge case where a full-DPU peer encounters overflow in its local gradients while averaging
+                # gradients (i.e. after the global unscale but before the global step). In that case, the next user-side
+                # call to .update would reset *all* optimizer states and cause .step to unscale gradients a second time.
+                # The offloaded optimizer is not affected by overflow in on-device gradients and should not be reset.
                 return True
             else:
                 self._check_inf_per_device(optimizer)
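
The added comment argues that the stashed state must live outside _per_optimizer_states so that a user-side .update() after a local overflow cannot wipe it. Below is a minimal, self-contained sketch of that edge case; DummyScaler, finish_global_step, and the toy state dict are hypothetical names used only for illustration and are not the actual hivemind or torch.cuda.amp.GradScaler code.

from copy import deepcopy


class DummyScaler:
    def __init__(self):
        # reset wholesale by .update() when an overflow is detected
        self._per_optimizer_states = {}
        # separate stash for the offloaded (DPU) optimizer; survives .update()
        self._inner_optimizer_states = {}

    def unscale_(self, opt_id):
        # after the "global" unscale, copy the per-optimizer state into the stash
        self._per_optimizer_states[opt_id] = {"stage": "unscaled"}
        self._inner_optimizer_states[opt_id] = deepcopy(self._per_optimizer_states[opt_id])

    def update(self, found_overflow: bool):
        if found_overflow:
            # a user-side .update() after a local overflow resets *all* per-optimizer states...
            self._per_optimizer_states = {}
        # ...but the stash for the offloaded optimizer is untouched

    def finish_global_step(self, opt_id):
        # the global step can still read the unscaled state saved earlier
        return self._inner_optimizer_states.pop(opt_id)


scaler = DummyScaler()
scaler.unscale_(opt_id=0)
scaler.update(found_overflow=True)  # local overflow between unscale and global step
assert scaler.finish_global_step(0)["stage"] == "unscaled"  # the stash survived the reset

Had the copy been kept inside _per_optimizer_states instead, the reset in .update() would have discarded it and the subsequent .step would have unscaled the gradients a second time, which is exactly the failure mode the commit comment describes.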