@@ -297,12 +297,45 @@ def test_progress_tracker():


 @pytest.mark.forked
+@pytest.mark.parametrize(
+    "use_local_updates, delay_state_averaging, delay_optimizer_step, delay_grad_averaging, reuse_grad_buffers",
+    # fmt: off
+    [
+        (False, False, False, False, False),
+        (False, True, False, False, False),
+        (False, True, True, True, False),
+        (False, False, False, False, True),
+        (False, True, True, True, True),
+        (False, True, True, False, True),
+        (True, False, False, False, False),
+        (True, True, False, False, False),
+    ],
+    # fmt: on
+)
 def test_optimizer(
+    use_local_updates: bool,
+    delay_state_averaging: bool,
+    delay_optimizer_step: bool,
+    delay_grad_averaging: bool,
+    reuse_grad_buffers: bool,
+):
+    _test_optimizer(
+        use_local_updates=use_local_updates,
+        delay_state_averaging=delay_state_averaging,
+        delay_grad_averaging=delay_grad_averaging,
+        delay_optimizer_step=delay_optimizer_step,
+        reuse_grad_buffers=reuse_grad_buffers,
+    )
+
+
+def _test_optimizer(
     num_peers: int = 1,
     num_clients: int = 0,
     target_batch_size: int = 32,
     total_epochs: int = 3,
+    use_local_updates: bool = False,
     reuse_grad_buffers: bool = True,
+    delay_state_averaging: bool = True,
     delay_grad_averaging: bool = True,
     delay_optimizer_step: bool = True,
     average_state_every: int = 1,
@@ -330,9 +363,11 @@ def test_optimizer(
         dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
         tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
         averager_opts=dict(request_timeout=0.5),
+        use_local_updates=use_local_updates,
         matchmaking_time=1.0,
         averaging_timeout=5.0,
         reuse_grad_buffers=reuse_grad_buffers,
+        delay_state_averaging=delay_state_averaging,
         delay_grad_averaging=delay_grad_averaging,
         delay_optimizer_step=delay_optimizer_step,
         average_state_every=average_state_every,
@@ -391,6 +426,10 @@ def test_optimizer(
     assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2

     assert not optimizer.state_averager.is_alive()
-    assert not optimizer.grad_averager.is_alive()
     assert not optimizer.tracker.is_alive()
+    if not use_local_updates:
+        assert not optimizer.grad_averager.is_alive()
+    else:
+        assert optimizer.grad_averager is None
+
     assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()
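
Note: the new use_local_updates flag changes whether hivemind.Optimizer creates a gradient averager at all, which is what the final assertions above branch on. A minimal standalone sketch of that behavior, following the form of hivemind's documented Optimizer constructor (the run_id, model, and hyperparameters here are hypothetical, not taken from this PR):

    import torch
    import hivemind

    # Throwaway DHT and model purely for illustration (hypothetical values).
    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(3, 1)
    sgd = torch.optim.SGD(model.parameters(), lr=0.1)

    optimizer = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",       # hypothetical run id
        optimizer=sgd,           # wrap a regular torch optimizer
        batch_size_per_step=4,
        target_batch_size=32,
        use_local_updates=True,  # apply optimizer steps with local gradients, average parameters in background
    )

    # With local updates enabled, gradients are never accumulated across peers,
    # so no gradient averager is constructed -- mirroring the test's else branch.
    assert optimizer.grad_averager is None

    optimizer.shutdown()
    dht.shutdown()

With use_local_updates=False (the default), the same constructor would create a gradient averager, and shutting the optimizer down is expected to stop it, which is the other branch of the assertion.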