add closure support (suggested by @SeanNaren)

justheuristic, 3 years ago
Commit 749ab83ea1
1 changed file with 16 additions and 5 deletions

hivemind/optim/experimental/optimizer.py (+16 -5)
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import logging
 import os
-from typing import Optional, Union
+from typing import Optional, Union, Callable
 
 import torch
 
@@ -187,14 +187,24 @@ class Optimizer(torch.optim.Optimizer):
         """If true, peer will discard local progress and attempt to download state from peers."""
         return self.local_epoch < self.tracker.global_epoch - self.epoch_tolerance
 
-    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
+    def step(self,
+             closure: Optional[Callable[[], torch.Tensor]] = None,
+             batch_size: Optional[int] = None,
+             grad_scaler: Optional[HivemindGradScaler] = None,
+             **kwargs):
         """
         Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
 
+        :param closure: A closure that reevaluates the model and returns the loss
         :param batch_size: optional override for batch_size_per_step from init
         :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
         if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
             raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
         if self.batch_size_per_step is None and batch_size is None:
@@ -203,13 +213,13 @@ class Optimizer(torch.optim.Optimizer):
 
         if self.should_load_state_from_peers:
             self.load_state_from_peers()
-            return
+            return loss
 
         if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
             logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
             self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
             self.grad_averager.reset_accumulated_grads_()
-            return
+            return loss
 
         self.grad_averager.accumulate_grads_(batch_size)
         self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
@@ -226,7 +236,7 @@ class Optimizer(torch.optim.Optimizer):
                 self.scheduled_round = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
 
         if not self.tracker.ready_to_update_epoch:
-            return
+            return loss
 
         with self.tracker.pause_updates():
             logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")
@@ -275,6 +285,7 @@ class Optimizer(torch.optim.Optimizer):
             self.grad_averager.reset_accumulated_grads_()
             self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
             logger.log(self.status_loglevel, f"Optimizer step done! Beginning next epoch {self.local_epoch}.")
+        return loss
 
     def step_aux(self, **kwargs):
         """