|
@@ -36,7 +36,6 @@ class Server(threading.Thread):
|
|
|
|
|
|
def __init__(
|
|
def __init__(
|
|
self,
|
|
self,
|
|
-
|
|
|
|
prefix: Optional[str],
|
|
prefix: Optional[str],
|
|
converted_model_name_or_path: str,
|
|
converted_model_name_or_path: str,
|
|
throughput: Union[float, str],
|
|
throughput: Union[float, str],
|
|
@@ -59,8 +58,8 @@ class Server(threading.Thread):
|
|
expiration: Optional[float] = None,
|
|
expiration: Optional[float] = None,
|
|
prefetch_batches: int = 1,
|
|
prefetch_batches: int = 1,
|
|
sender_threads: int = 1,
|
|
sender_threads: int = 1,
|
|
- max_block_selection_delay: float = 1,
|
|
|
|
- max_balance_check_period: float = 600,
|
|
|
|
|
|
+ mean_block_selection_delay: float = 0.5,
|
|
|
|
+ mean_balance_check_period: float = 300,
|
|
use_auth_token: Optional[str] = None,
|
|
use_auth_token: Optional[str] = None,
|
|
load_in_8bit: bool = False,
|
|
load_in_8bit: bool = False,
|
|
*,
|
|
*,
|
|
@@ -134,7 +133,8 @@ class Server(threading.Thread):
|
|
raise
|
|
raise
|
|
block_indices = range(first_block_index, last_block_index)
|
|
block_indices = range(first_block_index, last_block_index)
|
|
self.block_indices, self.num_blocks = block_indices, num_blocks
|
|
self.block_indices, self.num_blocks = block_indices, num_blocks
|
|
- self.max_block_selection_delay, self.max_balance_check_period = max_block_selection_delay, max_balance_check_period
|
|
|
|
|
|
+ self.mean_block_selection_delay = mean_block_selection_delay
|
|
|
|
+ self.mean_balance_check_period = mean_balance_check_period
|
|
|
|
|
|
self.stop = threading.Event()
|
|
self.stop = threading.Event()
|
|
if start:
|
|
if start:
|
|
@@ -172,7 +172,7 @@ class Server(threading.Thread):
|
|
self.module_container.ready.wait()
|
|
self.module_container.ready.wait()
|
|
|
|
|
|
while True:
|
|
while True:
|
|
- timeout = random.random() * self.max_balance_check_period
|
|
|
|
|
|
+ timeout = random.random() * 2 * self.mean_balance_check_period
|
|
if self.stop.wait(timeout):
|
|
if self.stop.wait(timeout):
|
|
return
|
|
return
|
|
if self._should_choose_other_blocks():
|
|
if self._should_choose_other_blocks():
|
|
@@ -186,7 +186,7 @@ class Server(threading.Thread):
|
|
|
|
|
|
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
|
|
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
|
|
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
|
|
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
|
|
- time.sleep(random.random() * self.max_block_selection_delay)
|
|
|
|
|
|
+ time.sleep(random.random() * 2 * self.mean_block_selection_delay)
|
|
|
|
|
|
assert self.num_blocks is not None
|
|
assert self.num_blocks is not None
|
|
uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|