
Add server-side gradient accumulation

Max Ryabinin, 3 years ago
commit b48220577e
2 changed files with 10 additions and 5 deletions
  1. hivemind/moe/server/__init__.py (+1, -0)
  2. hivemind/moe/server/expert_backend.py (+9, -5)

hivemind/moe/server/__init__.py (+1, -0)

@@ -331,6 +331,7 @@ class Server(threading.Thread):
                     clip_grad_norm=clip_grad_norm,
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
+                    target_batch_size=averaging_target_batch_size,
                 )
 
         if checkpoint_dir is not None:

hivemind/moe/server/expert_backend.py (+9, -5)

@@ -57,6 +57,7 @@ class ExpertBackend:
         num_warmup_steps: int = None,
         num_total_steps: int = None,
         clip_grad_norm: float = None,
+        target_batch_size: int = None,
         **kwargs,
     ):
         super().__init__()
@@ -98,6 +99,7 @@ class ExpertBackend:
 
         self.update_count = 0
         self.examples_processed = 0
+        self.target_batch_size = target_batch_size
 
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
@@ -182,21 +184,23 @@ class ExpertBackend:
         """
         Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
         """
+        self.examples_processed += batch_size
+
         if self.clip_grad_norm is not None:
             torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)
 
         if isinstance(self.optimizer, hivemind.CollaborativeOptimizer):
             self.optimizer.step(batch_size)
         else:
-            self.optimizer.step()
-            self.optimizer.zero_grad()
+            if self.target_batch_size is None or self.examples_processed % self.target_batch_size == 0:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
 
-            if self.scheduler is not None:
-                self.scheduler.step()
+                if self.scheduler is not None:
+                    self.scheduler.step()
 
         # TODO update_count is not always incremented if CollaborativeOptimizer is used
         self.update_count += 1
-        self.examples_processed += batch_size
 
     def get_stats(self) -> Dict:
         """