
better zero_grad behavior in CollaborativeOptimizer (#221)

justheuristic, 4 years ago
parent commit b906ae94ed
1 changed file with 11 additions and 4 deletions:
  1. hivemind/client/optim/collaborative.py  +11 −4

hivemind/client/optim/collaborative.py  +11 −4

@@ -1,5 +1,4 @@
 from __future__ import annotations
-import warnings
 from dataclasses import dataclass
 from threading import Thread, Lock, Event
 from typing import Optional, Iterator
@@ -190,7 +189,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 self.averager.local_step += 1
 
             self.opt.step()
-            self.opt.zero_grad()
+            if self.reuse_grad_buffers:
+                self.opt.zero_grad()
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step()
@@ -230,6 +230,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @torch.no_grad()
     def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
+        if self.reuse_grad_buffers:
+            return
         for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
             grad_buf[...] = grad_acc.to(grad_buf.device)
             if scale_by is not None:
@@ -237,7 +239,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @torch.no_grad()
     def reset_accumulated_grads_(self):
-        for grad_buf in self._grad_buffers():
+        if self.reuse_grad_buffers:
+            return
+        for grad_buf in self.accumulated_grads():
             grad_buf.zero_()
 
     def report_training_progress(self):
@@ -316,7 +320,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             next_fetch_time=current_time + time_to_next_fetch)
 
     def zero_grad(self, *args, **kwargs):
-        warnings.warn("CollaborativeOptimizer.zero_grad is a no-op and doesn't need to be called")
+        if self.reuse_grad_buffers:
+            raise ValueError(f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                             f"call zero_grad manually. Gradients will be refreshed internally.")
+        return self.opt.zero_grad(*args, **kwargs)
 
     @staticmethod
     def is_valid_peer_state(state):
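
The zero_grad change above can be summarized with a small, self-contained sketch. This is not the real hivemind class (CollaborativeOptimizer also needs a DHT, an averager, and collaboration state); ZeroGradContractDemo below is a hypothetical stand-in that mirrors only the new zero_grad contract from this commit: with reuse_grad_buffers=True a manual call now raises ValueError instead of being a silent no-op, otherwise it delegates to the wrapped optimizer's zero_grad.

    # Sketch only: a hypothetical stand-in mirroring the new zero_grad behavior,
    # not the actual hivemind CollaborativeOptimizer.
    import torch


    class ZeroGradContractDemo:
        """Illustrates the zero_grad contract introduced in this commit."""

        def __init__(self, opt: torch.optim.Optimizer, reuse_grad_buffers: bool):
            self.opt = opt
            self.reuse_grad_buffers = reuse_grad_buffers

        def zero_grad(self, *args, **kwargs):
            if self.reuse_grad_buffers:
                # reuse_grad_buffers=True: gradients are refreshed internally,
                # so a manual zero_grad call is an error rather than a silent no-op.
                raise ValueError(
                    f"When running {self.__class__.__name__} with reuse_grad_buffers=True, "
                    f"user should never call zero_grad manually. "
                    f"Gradients will be refreshed internally."
                )
            # reuse_grad_buffers=False: simply delegate to the wrapped optimizer.
            return self.opt.zero_grad(*args, **kwargs)


    model = torch.nn.Linear(4, 2)
    inner = torch.optim.SGD(model.parameters(), lr=0.1)

    # Delegation path: behaves like a plain optimizer.
    ZeroGradContractDemo(inner, reuse_grad_buffers=False).zero_grad()

    # Guarded path: manual zero_grad is rejected.
    try:
        ZeroGradContractDemo(inner, reuse_grad_buffers=True).zero_grad()
    except ValueError as e:
        print("raised as expected:", e)

In the actual class, the same flag also gates the accumulator helpers shown in the diff: apply_accumulated_grads_ and reset_accumulated_grads_ return early when reuse_grad_buffers is set, and the internal zero_grad after opt.step() only runs in that mode.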