@@ -1,5 +1,4 @@
 from __future__ import annotations
-import warnings
 from dataclasses import dataclass
 from threading import Thread, Lock, Event
 from typing import Optional, Iterator
@@ -190,7 +189,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                 self.averager.local_step += 1
 
             self.opt.step()
-            self.opt.zero_grad()
+            if self.reuse_grad_buffers:
+                self.opt.zero_grad()
             self.reset_accumulated_grads_()
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.collaboration_state.register_step()
@@ -230,6 +230,8 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @torch.no_grad()
     def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
+        if self.reuse_grad_buffers:
+            return
         for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
             grad_buf[...] = grad_acc.to(grad_buf.device)
         if scale_by is not None:
@@ -237,7 +239,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
     @torch.no_grad()
     def reset_accumulated_grads_(self):
-        for grad_buf in self._grad_buffers():
+        if self.reuse_grad_buffers:
+            return
+        for grad_buf in self.accumulated_grads():
             grad_buf.zero_()
 
     def report_training_progress(self):
@@ -316,7 +320,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             next_fetch_time=current_time + time_to_next_fetch)
 
     def zero_grad(self, *args, **kwargs):
-        warnings.warn("CollaborativeOptimizer.zero_grad is a no-op and doesn't need to be called")
+        if self.reuse_grad_buffers:
+            raise ValueError(f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                             f"call zero_grad manually. Gradients will be refreshed internally.")
+        return self.opt.zero_grad(*args, **kwargs)
 
     @staticmethod
     def is_valid_peer_state(state):
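
A minimal sketch of how a caller's training loop differs between the two modes after this change. It assumes `opt` is an already-constructed CollaborativeOptimizer wrapping a regular torch optimizer (DHT and constructor arguments omitted); `model`, `make_batch`, and `compute_loss` are placeholder names that do not appear in the patch.

    # Sketch only: `opt` is assumed to be a CollaborativeOptimizer; `model`,
    # `make_batch`, and `compute_loss` are placeholders for user code.

    # reuse_grad_buffers=False: gradients go through the optimizer's accumulators,
    # and zero_grad() now forwards to the wrapped optimizer instead of warning.
    for _ in range(100):
        loss = compute_loss(model, make_batch())
        loss.backward()
        opt.step()
        opt.zero_grad()  # forwarded to opt.opt.zero_grad(*args, **kwargs)

    # reuse_grad_buffers=True: the optimizer owns the .grad buffers and clears them
    # itself after each global step; calling opt.zero_grad() here raises ValueError.
    for _ in range(100):
        loss = compute_loss(model, make_batch())
        loss.backward()
        opt.step()  # gradient clearing is handled internally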