
better zero_grad behavior in CollaborativeOptimizer (#221)

justheuristic 4 years ago
parent commit b906ae94ed
1 changed file with 11 additions and 4 deletions
      hivemind/client/optim/collaborative.py

+11 -4
hivemind/client/optim/collaborative.py

@@ -1,5 +1,4 @@
 from __future__ import annotations
-import warnings
 from dataclasses import dataclass
 from threading import Thread, Lock, Event
 from typing import Optional, Iterator
@@ -190,7 +189,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 self.averager.local_step += 1

             self.opt.step()
-            self.opt.zero_grad()
+            if self.reuse_grad_buffers:
+                self.opt.zero_grad()
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step()
@@ -230,6 +230,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):

     @torch.no_grad()
     def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
+        if self.reuse_grad_buffers:
+            return
         for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
             grad_buf[...] = grad_acc.to(grad_buf.device)
             if scale_by is not None:
@@ -237,7 +239,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):

     @torch.no_grad()
     def reset_accumulated_grads_(self):
-        for grad_buf in self._grad_buffers():
+        if self.reuse_grad_buffers:
+            return
+        for grad_buf in self.accumulated_grads():
             grad_buf.zero_()

     def report_training_progress(self):
@@ -316,7 +320,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             next_fetch_time=current_time + time_to_next_fetch)

     def zero_grad(self, *args, **kwargs):
-        warnings.warn("CollaborativeOptimizer.zero_grad is a no-op and doesn't need to be called")
+        if self.reuse_grad_buffers:
+            raise ValueError(f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                             f"call zero_grad manually. Gradients will be refreshed internally.")
+        return self.opt.zero_grad(*args, **kwargs)

     @staticmethod
     def is_valid_peer_state(state):
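
For reference, a minimal sketch of the training-loop contract this change introduces. Only the reuse_grad_buffers flag and the step()/zero_grad() behaviour come from the diff above; the top-level import of CollaborativeOptimizer and the constructor arguments (dht, prefix, target_batch_size, batch_size_per_step) are assumptions for illustration based on hivemind of that era, not part of this commit.

# Illustrative sketch only: constructor arguments below are assumed, not taken from this diff.
import torch
import torch.nn.functional as F
import hivemind

model = torch.nn.Linear(16, 1)
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
dht = hivemind.DHT(start=True)

collab_opt = hivemind.CollaborativeOptimizer(
    opt=base_opt, dht=dht, prefix="demo_run",
    target_batch_size=4096, batch_size_per_step=32,
    reuse_grad_buffers=True)  # gradients accumulate directly in param.grad

for _ in range(10):
    x, y = torch.randn(32, 16), torch.randn(32, 1)
    F.mse_loss(model(x), y).backward()
    collab_opt.step()
    # With reuse_grad_buffers=True, do NOT call collab_opt.zero_grad():
    # buffers are cleared internally and zero_grad() now raises ValueError.
    # With reuse_grad_buffers=False, zero_grad() simply forwards to
    # base_opt.zero_grad(), replacing the old no-op-with-warning behaviour.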