|
@@ -6,10 +6,13 @@ import argparse
|
|
|
import torch
|
|
|
|
|
|
import hivemind
|
|
|
-from hivemind.utils import LOCALHOST, increase_file_limit
|
|
|
+from hivemind.utils import LOCALHOST, increase_file_limit, get_logger
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
|
|
|
|
|
|
+logger = get_logger(__name__)
|
|
|
+
|
|
|
+
|
|
|
def sample_tensors(hid_size, num_layers):
|
|
|
tensors = []
|
|
|
for i in range(num_layers):
|
|
@@ -38,8 +41,11 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
|
|
|
peer_tensors = [sample_tensors(hid_size, num_layers)
|
|
|
for _ in range(num_peers)]
|
|
|
processes = {dht_root}
|
|
|
+ lock_stats = threading.Lock()
|
|
|
+ successful_steps = total_steps = 0
|
|
|
|
|
|
def run_averager(index):
|
|
|
+ nonlocal successful_steps, total_steps, lock_stats
|
|
|
dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
|
|
|
initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
|
|
|
start=True)
|
|
@@ -50,11 +56,17 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
|
|
|
averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
|
|
|
processes.update({dht, averager})
|
|
|
|
|
|
- print(end=f'<started {index}>\n', flush=True)
|
|
|
- for _ in range(num_rounds):
|
|
|
- success = averager.step(timeout=round_timeout)
|
|
|
- print(end=('+' if success else '-'), flush=True)
|
|
|
- print(end=f'<finished {index}>\n', flush=True)
|
|
|
+ logger.info(f'Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}')
|
|
|
+ for step in range(num_rounds):
|
|
|
+ try:
|
|
|
+ success = averager.step(timeout=round_timeout) is not None
|
|
|
+ except:
|
|
|
+ success = False
|
|
|
+ with lock_stats:
|
|
|
+ successful_steps += int(success)
|
|
|
+ total_steps += 1
|
|
|
+ logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
|
|
|
+ logger.info(f"Averager {index}: done.")
|
|
|
|
|
|
threads = []
|
|
|
for i in range(num_peers):
|
|
@@ -67,10 +79,8 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
|
|
|
for thread in threads:
|
|
|
thread.join()
|
|
|
|
|
|
- print(f"\ntest run took {time.time() - t:.3f} seconds")
|
|
|
-
|
|
|
- for process in processes:
|
|
|
- process.terminate()
|
|
|
+ logger.info(f"Benchmark finished in {time.time() - t:.3f} seconds.")
|
|
|
+ logger.info(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
@@ -80,9 +90,9 @@ if __name__ == "__main__":
|
|
|
parser.add_argument('--num_rounds', type=int, default=5, required=False)
|
|
|
parser.add_argument('--hid_size', type=int, default=256, required=False)
|
|
|
parser.add_argument('--num_layers', type=int, default=3, required=False)
|
|
|
- parser.add_argument('--averaging_expiration', type=float, default=15, required=False)
|
|
|
- parser.add_argument('--round_timeout', type=float, default=30, required=False)
|
|
|
- parser.add_argument('--request_timeout', type=float, default=3, required=False)
|
|
|
+ parser.add_argument('--averaging_expiration', type=float, default=5, required=False)
|
|
|
+ parser.add_argument('--round_timeout', type=float, default=15, required=False)
|
|
|
+ parser.add_argument('--request_timeout', type=float, default=1, required=False)
|
|
|
parser.add_argument('--spawn_dtime', type=float, default=0.1, required=False)
|
|
|
parser.add_argument('--increase_file_limit', action="store_true")
|
|
|
args = vars(parser.parse_args())
|