Aleksandr Borzunov · 3 years ago
commit 2738f8ca39

+ 1 - 1
benchmarks/benchmark_optimizer.py

@@ -77,7 +77,7 @@ def benchmark_optimizer(args: TrainingArguments):
         assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
 
         optimizer = Optimizer(
-            prefix=args.prefix,
+            run_id=args.prefix,
             target_batch_size=args.target_batch_size,
             batch_size_per_step=batch_size,
             params=model.parameters(),
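Since `Optimizer.__init__` takes keyword-only arguments (the bare `*` in the signature below), call sites that keep the old keyword fail fast after this rename. A minimal sketch of both outcomes, assuming `dht` and `model` are set up as in this benchmark, `"my_run"` is a placeholder run name, and the signature has no `**kwargs` catch-all:

    # New-style construction, mirroring the call site above:
    opt = Optimizer(
        dht=dht,
        run_id="my_run",                 # renamed from prefix=
        target_batch_size=4096,
        batch_size_per_step=4,
        params=model.parameters(),
        optimizer=torch.optim.Adam,
    )

    # Old-style construction now raises immediately:
    #   Optimizer(prefix="my_run", ...)
    #   TypeError: __init__() got an unexpected keyword argument 'prefix'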

+ 8 - 8
hivemind/optim/experimental/optimizer.py

@@ -35,7 +35,7 @@ class Optimizer(torch.optim.Optimizer):
 
     >>> model = transformers.AutoModel.from_pretrained("albert-xxlarge-v2")
     >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
-    >>> opt = hivemind.Optimizer(model.parameters(), optim_cls=torch.optim.Adam, prefix="run_42",
+    >>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", params=model.parameters(), optimizer=torch.optim.Adam,
     >>>                          target_batch_size=4096, batch_size_per_step=4)
     >>> while True:
     >>>     loss = compute_loss_on_batch(model, batch_size=4)
@@ -54,7 +54,7 @@ class Optimizer(torch.optim.Optimizer):
       other peers have already made some progress and changed their learning rate accordingly.
 
     :param dht: a running hivemind.DHT instance connected to other peers
-    :param prefix: a unique name of this experiment, used as a common prefix for all DHT keys
+    :param run_id: a unique name of this experiment, used as a common prefix for all DHT keys
     :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
     :param batch_size_per_step: before each call to .step, the user should accumulate gradients over this many samples
     :param optimizer: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
@@ -88,7 +88,7 @@ class Optimizer(torch.optim.Optimizer):
         self,
         *,
         dht: DHT,
-        prefix: str,
+        run_id: str,
         target_batch_size: int,
         batch_size_per_step: Optional[int] = None,
         optimizer: Union[TorchOptimizer, OptimizerFactory],
@@ -114,7 +114,7 @@ class Optimizer(torch.optim.Optimizer):
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
 
-        self.dht, self.prefix, self.client_mode, self.auxiliary = dht, prefix, client_mode, auxiliary
+        self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
         self.matchmaking_time, self.average_state_every = matchmaking_time, average_state_every
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
@@ -141,7 +141,7 @@ class Optimizer(torch.optim.Optimizer):
     def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
         return TrainingStateAverager(
             dht=self.dht,
-            prefix=f"{self.prefix}_state_averager",
+            prefix=f"{self.run_id}_state_averager",
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
             status_loglevel=self.status_loglevel,
@@ -157,7 +157,7 @@ class Optimizer(torch.optim.Optimizer):
         assert hasattr(self, "state_averager"), "must initialize state averager first"
         grad_averager = GradientAverager(
             dht=self.dht,
-            prefix=f"{self.prefix}_grad_averager",
+            prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
@@ -177,7 +177,7 @@ class Optimizer(torch.optim.Optimizer):
     def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
         return ProgressTracker(
             dht=self.dht,
-            prefix=self.prefix,
+            prefix=self.run_id,
             target_batch_size=target_batch_size,
             client_mode=self.client_mode,
             status_loglevel=self.status_loglevel,
@@ -444,7 +444,7 @@ class Optimizer(torch.optim.Optimizer):
         )
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(prefix={self.prefix}, epoch={self.local_epoch})"
+        return f"{self.__class__.__name__}(run_id={self.run_id}, epoch={self.local_epoch})"
 
     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
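Internally, the public `run_id` still fans out into per-component DHT key prefixes, as the three `_make_*` hunks above show; the sub-component constructors keep their `prefix=` argument, so only `Optimizer`'s public API changes here. A quick sketch of the resulting namespace for one run:

    run_id = "run_42"
    # DHT key prefixes under which peers sharing this run_id rendezvous:
    state_prefix = f"{run_id}_state_averager"   # passed to TrainingStateAverager
    grads_prefix = f"{run_id}_grad_averager"    # passed to GradientAverager
    tracker_prefix = run_id                     # ProgressTracker uses run_id as-is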

+ 5 - 5
tests/test_optimizer.py

@@ -290,7 +290,7 @@ def test_progress_tracker():
 def test_optimizer(
     num_peers: int = 1,
     num_clients: int = 0,
-    target_batch_size: int = 64,
+    target_batch_size: int = 32,
     total_epochs: int = 3,
     reuse_grad_buffers: bool = True,
     delay_grad_averaging: bool = True,
@@ -311,7 +311,7 @@ def test_optimizer(
         assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
 
         optimizer = Optimizer(
-            prefix="test_run",
+            run_id="test_run",
             target_batch_size=target_batch_size,
             batch_size_per_step=batch_size,
             params=model.parameters(),
@@ -334,7 +334,7 @@ def test_optimizer(
         prev_time = time.perf_counter()
 
         while optimizer.local_epoch < total_epochs:
-            time.sleep(max(0.0, prev_time + random.gauss(batch_time, 0.1) - time.perf_counter()))
+            time.sleep(max(0.0, prev_time + batch_time - time.perf_counter()))
             batch = torch.randint(0, len(features), (batch_size,))
 
             loss = F.mse_loss(model(features[batch]), targets[batch])
@@ -377,8 +377,8 @@ def test_optimizer(
     assert isinstance(optimizer, Optimizer)
     assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
     expected_samples_accumulated = target_batch_size * total_epochs
-    assert expected_samples_accumulated <= total_samples_accumulated.value <= 2 * expected_samples_accumulated
-    assert 4 / 0.3 * 0.9 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.1
+    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
+    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
 
     assert not optimizer.state_averager.is_alive()
     assert not optimizer.grad_averager.is_alive()
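For reference, the arithmetic behind the loosened assertions above (a sketch; `batch_size = 4` and `batch_time = 0.3` are inferred from the `4 / 0.3` in the assertion and the test's defaults):

    target_batch_size, total_epochs = 32, 3        # halved from 64 in this commit
    expected = target_batch_size * total_epochs    # 96 samples minimum
    upper = expected * 1.2                         # 115.2 samples; old bound was 2x

    ideal_sps = 4 / 0.3                            # ≈ 13.33 samples/second
    low, high = ideal_sps * 0.8, ideal_sps * 1.2   # ≈ 10.67 .. 16.0; old tolerance was ±10%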