|
@@ -1,7 +1,9 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import multiprocessing as mp
|
|
|
+import random
|
|
|
import threading
|
|
|
+import time
|
|
|
from typing import Dict, Literal, Optional, Sequence, Union
|
|
|
|
|
|
import torch
|
|
@@ -111,6 +113,7 @@ class Server(threading.Thread):
|
|
|
custom_module_path=None,
|
|
|
update_period: float = 30,
|
|
|
expiration: Optional[float] = None,
|
|
|
+ max_block_selection_delay: float = 1,
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
*,
|
|
|
start: bool,
|
|
@@ -158,6 +161,10 @@ class Server(threading.Thread):
|
|
|
raise
|
|
|
block_indices = range(first_block_index, last_block_index)
|
|
|
else:
|
|
|
+ # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
|
|
|
+ # this delay decreases the probability of a race condition while choosing the best blocks to serve.
|
|
|
+ time.sleep(random.random() * max_block_selection_delay)
|
|
|
+
|
|
|
assert num_blocks is not None
|
|
|
uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
|
|
|
module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
|