@@ -200,17 +200,17 @@ class Optimizer(torch.optim.Optimizer):
         :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
-        loss = None
-        if closure is not None:
-            with torch.enable_grad():
-                loss = closure()
-
         if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
             raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
         if self.batch_size_per_step is None and batch_size is None:
             raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
         if self.should_load_state_from_peers:
             self.load_state_from_peers()
             return loss
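For context, a minimal usage sketch of the reordered `.step` call; the training-loop names (`opt`, `model`, `criterion`, `batch`) are illustrative assumptions, not part of this change:

```python
# Hypothetical training-loop snippet: `opt` stands for the optimizer patched above;
# `model`, `criterion`, and `batch` are placeholders for the user's own objects.
def closure():
    opt.zero_grad()  # standard torch.optim.Optimizer API
    loss = criterion(model(batch[0]), batch[1])
    loss.backward()
    return loss

# With this reordering, argument validation (grad_scaler type, batch_size) runs before
# the closure, so a misconfigured call fails fast instead of wasting a forward/backward pass.
loss = opt.step(closure=closure, batch_size=batch[0].shape[0])
```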