
add closure support (suggested by @SeanNaren)

justheuristic 3 years ago
parent commit 0add448239
1 changed file with 5 additions and 5 deletions

hivemind/optim/experimental/optimizer.py  +5 −5

@@ -200,17 +200,17 @@ class Optimizer(torch.optim.Optimizer):
         :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
         if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
             raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
         if self.batch_size_per_step is None and batch_size is None:
             raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
         if self.should_load_state_from_peers:
             self.load_state_from_peers()
             return loss
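
For context, a minimal sketch of how the closure-based `.step` touched by this commit is used. Only the `closure` and `batch_size` parameters are taken from the diff itself; the toy model, the batch, and the `opt = ...` construction are hypothetical placeholders (the real hivemind Optimizer constructor arguments are outside this diff):

    import torch
    import torch.nn.functional as F

    # Hypothetical setup, not part of this commit: a toy model and one batch.
    model = torch.nn.Linear(16, 1)
    inputs, targets = torch.randn(32, 16), torch.randn(32, 1)

    # `opt` stands in for a hivemind Optimizer instance; its constructor
    # arguments (DHT, run id, target batch size, ...) are not shown here.
    opt = ...

    def closure():
        # Standard torch.optim closure contract: re-run forward/backward and
        # return the loss. Optimizer.step evaluates it under
        # torch.enable_grad(), as the hunk above shows.
        opt.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        loss.backward()
        return loss

    step_loss = opt.step(closure=closure, batch_size=inputs.shape[0])

Note the reordering in the hunk: the closure now runs only after the grad_scaler and batch_size arguments are validated, so an invalid call fails fast instead of first paying for a forward/backward pass.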