
fix use_local_updates in optimizer (#468)

- fix use_local_updates in Optimizer
- add test cases to ensure that it works
- ensure that grad_compression is always used

Co-authored-by: artek0chumak <artek.chumak@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic, 3 years ago
Parent
Commit
35851c8ce9
2 changed files with 48 additions and 7 deletions
  1. hivemind/optim/optimizer.py (+8 -6)
  2. tests/test_optimizer.py (+40 -1)
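
To illustrate the fix summarized above, here is a minimal sketch of the use_local_updates path; the toy model, run id, and hyperparameters are made up for illustration and not taken from the repository. With use_local_updates=True the Optimizer applies updates locally and keeps grad_averager as None, which the patched load-state hunk below now tolerates.

import torch
import hivemind

# Toy setup for illustration only; model, run_id, and hyperparameters are arbitrary.
model = torch.nn.Linear(8, 1)
dht = hivemind.DHT(start=True)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_local_updates",  # hypothetical run id
    params=model.parameters(),
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    target_batch_size=32,
    batch_size_per_step=4,
    use_local_updates=True,  # apply optimizer steps locally; no gradient averaging
)

# With use_local_updates=True, no gradient averager is created at all,
# and the patched load_state_from_peers no longer dereferences it.
assert opt.grad_averager is None

opt.shutdown()
dht.shutdown()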

hivemind/optim/optimizer.py (+8 -6)

@@ -189,7 +189,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
-        grad_averager_factory: Optional[GradientAveragerFactory] = GradientAverager,
+        grad_averager_factory: Optional[GradientAveragerFactory] = None,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
@@ -262,9 +262,9 @@ class Optimizer(torch.optim.Optimizer):
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
-        if grad_averager_factory is not None and not use_local_updates:
+        if not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, grad_averager_factory=grad_averager_factory
+                grad_averager_factory, reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression
             )
         else:
             self.grad_averager = None
@@ -297,9 +297,10 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )
 
-    def _make_gradient_averager(self, grad_averager_factory, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, factory: Optional[GradientAveragerFactory], **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        grad_averager = grad_averager_factory(
+        factory = factory if factory is not None else GradientAverager
+        grad_averager = factory(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
@@ -691,7 +692,8 @@ class Optimizer(torch.optim.Optimizer):
             while True:
                 try:
                     self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
-                    self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    if self.grad_averager is not None:
+                        self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
                 except KeyboardInterrupt:
                     raise
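
The hunks above also change the non-local path: grad_averager_factory now defaults to None and is resolved to the plain GradientAverager inside _make_gradient_averager, and the user-supplied grad_compression is now passed to that factory. A hedged sketch of this path, again with a made-up model and hyperparameters; Float16Compression is one of hivemind's built-in compression schemes.

import torch
import hivemind

model = torch.nn.Linear(8, 1)  # toy model for illustration
dht = hivemind.DHT(start=True)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_grad_compression",  # hypothetical run id
    params=model.parameters(),
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    target_batch_size=32,
    batch_size_per_step=4,
    use_local_updates=False,  # gradients are accumulated and averaged across peers
    grad_compression=hivemind.Float16Compression(),  # forwarded to the gradient averager by this fix
)

# With averaging enabled, the default factory (GradientAverager) is used
# and the compression above applies to the averaged gradients.
assert opt.grad_averager is not None

opt.shutdown()
dht.shutdown()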

tests/test_optimizer.py (+40 -1)

@@ -297,12 +297,45 @@ def test_progress_tracker():
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize(
+    "use_local_updates, delay_state_averaging, delay_optimizer_step, delay_grad_averaging, reuse_grad_buffers",
+    # fmt: off
+    [
+        (False, False, False, False, False),
+        (False, True, False, False, False),
+        (False, True, True, True, False),
+        (False, False, False, False, True),
+        (False, True, True, True, True),
+        (False, True, True, False, True),
+        (True, False, False, False, False),
+        (True, True, False, False, False,),
+    ],
+    # fmt: on
+)
 def test_optimizer(
+    use_local_updates: bool,
+    delay_state_averaging: bool,
+    delay_optimizer_step: bool,
+    delay_grad_averaging: bool,
+    reuse_grad_buffers: bool,
+):
+    _test_optimizer(
+        use_local_updates=use_local_updates,
+        delay_state_averaging=delay_state_averaging,
+        delay_grad_averaging=delay_grad_averaging,
+        delay_optimizer_step=delay_optimizer_step,
+        reuse_grad_buffers=reuse_grad_buffers,
+    )
+
+
+def _test_optimizer(
     num_peers: int = 1,
     num_clients: int = 0,
     target_batch_size: int = 32,
     total_epochs: int = 3,
+    use_local_updates: bool = False,
     reuse_grad_buffers: bool = True,
+    delay_state_averaging: bool = True,
     delay_grad_averaging: bool = True,
     delay_optimizer_step: bool = True,
     average_state_every: int = 1,
@@ -330,9 +363,11 @@ def test_optimizer(
             dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
             tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
             averager_opts=dict(request_timeout=0.5),
+            use_local_updates=use_local_updates,
             matchmaking_time=1.0,
             averaging_timeout=5.0,
             reuse_grad_buffers=reuse_grad_buffers,
+            delay_state_averaging=delay_state_averaging,
             delay_grad_averaging=delay_grad_averaging,
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
@@ -391,6 +426,10 @@ def test_optimizer(
     assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
 
     assert not optimizer.state_averager.is_alive()
-    assert not optimizer.grad_averager.is_alive()
     assert not optimizer.tracker.is_alive()
+    if not use_local_updates:
+        assert not optimizer.grad_averager.is_alive()
+    else:
+        assert optimizer.grad_averager is None
+
     assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()
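
To run the new parametrized cases locally, pytest can be invoked from Python; a minimal sketch (the -k expression is a substring match on the test name added above, and the forked marker assumes the pytest-forked plugin used by hivemind's test suite is installed):

import pytest

# Select the parametrized test_optimizer cases from tests/test_optimizer.py;
# -v prints one line per parameter combination.
pytest.main(["tests/test_optimizer.py", "-k", "test_optimizer", "-v"])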