|
@@ -217,13 +217,14 @@ def test_progress_tracker():
|
|
|
tracker.report_local_progress(local_epoch, samples_accumulated)
|
|
|
|
|
|
if tracker.ready_to_update_epoch:
|
|
|
+ if index == 4 and local_epoch >= 4:
|
|
|
+ time.sleep(0.5)
|
|
|
+ break
|
|
|
+
|
|
|
with tracker.pause_updates():
|
|
|
local_epoch = tracker.update_epoch(local_epoch + 1)
|
|
|
samples_accumulated = 0
|
|
|
|
|
|
- if index == 4 and local_epoch >= 5:
|
|
|
- time.sleep(0.5)
|
|
|
- break
|
|
|
|
|
|
emas[index] = tracker.performance_ema.samples_per_second
|
|
|
tracker.shutdown()
|
|
@@ -249,16 +250,19 @@ def test_progress_tracker():
|
|
|
)
|
|
|
barrier.wait()
|
|
|
|
|
|
- current_step = 0
|
|
|
+ local_epoch = 0
|
|
|
last_timestamp = hivemind.get_dht_time()
|
|
|
step_time_deltas = []
|
|
|
|
|
|
- while current_step < 6:
|
|
|
+ while local_epoch < 6:
|
|
|
time.sleep(0.1)
|
|
|
- if tracker.global_progress.epoch > current_step:
|
|
|
+
|
|
|
+ if tracker.ready_to_update_epoch:
|
|
|
+ with tracker.pause_updates():
|
|
|
+ local_epoch = tracker.update_epoch(local_epoch + 1)
|
|
|
+
|
|
|
time_delta = hivemind.get_dht_time() - last_timestamp
|
|
|
- current_step = tracker.global_progress.epoch
|
|
|
- if current_step == 2:
|
|
|
+ if local_epoch == 2:
|
|
|
delayed_start_evt.set()
|
|
|
|
|
|
last_timestamp = hivemind.get_dht_time()
|
|
@@ -273,6 +277,7 @@ def test_progress_tracker():
|
|
|
assert not tracker.is_alive()
|
|
|
|
|
|
mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
|
|
|
+ print(step_time_deltas, mean_step_time)
|
|
|
for i in (0, 1, 5): # Without the 4th worker (the fastest one)
|
|
|
assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
|
|
|
for i in (2, 3, 4): # With the 4th worker
|