|
@@ -60,6 +60,7 @@ class Server(threading.Thread):
|
|
|
prefetch_batches: int = 1,
|
|
|
sender_threads: int = 1,
|
|
|
max_block_selection_delay: float = 1,
|
|
|
+ max_balance_check_period: float = 600,
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
load_in_8bit: bool = False,
|
|
|
*,
|
|
@@ -94,8 +95,6 @@ class Server(threading.Thread):
|
|
|
logger.info(f"Automatic dht prefix: {prefix}")
|
|
|
self.prefix = prefix
|
|
|
|
|
|
- assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
|
|
|
-
|
|
|
if expiration is None:
|
|
|
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
|
|
|
self.expiration = expiration
|
|
@@ -125,6 +124,7 @@ class Server(threading.Thread):
|
|
|
revision=revision,
|
|
|
)
|
|
|
|
|
|
+ assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
|
|
|
if block_indices is not None:
|
|
|
try:
|
|
|
first_block_index, last_block_index = block_indices.split(":")
|
|
@@ -133,51 +133,68 @@ class Server(threading.Thread):
|
|
|
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
|
|
|
raise
|
|
|
block_indices = range(first_block_index, last_block_index)
|
|
|
- else:
|
|
|
- # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
|
|
|
- # this delay decreases the probability of a race condition while choosing the best blocks to serve.
|
|
|
- time.sleep(random.random() * max_block_selection_delay)
|
|
|
-
|
|
|
- assert num_blocks is not None
|
|
|
- uids = [f"{prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
|
- module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
|
|
|
- block_indices = choose_best_blocks(num_blocks, module_infos)
|
|
|
- self.block_indices = block_indices
|
|
|
+ self.block_indices, self.num_blocks = block_indices, num_blocks
|
|
|
+ self.max_block_selection_delay, self.max_balance_check_period = max_block_selection_delay, max_balance_check_period
|
|
|
|
|
|
self.stop = threading.Event()
|
|
|
if start:
|
|
|
self.start()
|
|
|
|
|
|
def run(self):
|
|
|
- self.module_container = ModuleContainer.create(
|
|
|
- dht=self.dht,
|
|
|
- prefix=self.prefix,
|
|
|
- converted_model_name_or_path=self.converted_model_name_or_path,
|
|
|
- block_config=self.block_config,
|
|
|
- memory_cache=self.memory_cache,
|
|
|
- throughput=self.throughput,
|
|
|
- block_indices=self.block_indices,
|
|
|
- num_handlers=self.num_handlers,
|
|
|
- min_batch_size=self.min_batch_size,
|
|
|
- max_batch_size=self.max_batch_size,
|
|
|
- inference_max_length=self.inference_max_length,
|
|
|
- torch_dtype=self.torch_dtype,
|
|
|
- cache_dir=self.cache_dir,
|
|
|
- device=self.device,
|
|
|
- compression=self.compression,
|
|
|
- stats_report_interval=self.stats_report_interval,
|
|
|
- update_period=self.update_period,
|
|
|
- expiration=self.expiration,
|
|
|
- prefetch_batches=self.prefetch_batches,
|
|
|
- sender_threads=self.sender_threads,
|
|
|
- use_auth_token=self.use_auth_token,
|
|
|
- load_in_8bit=self.load_in_8bit,
|
|
|
- start=True,
|
|
|
- )
|
|
|
- try:
|
|
|
- self.stop.wait()
|
|
|
- finally:
|
|
|
- self.module_container.shutdown()
|
|
|
+ while True:
|
|
|
+ block_indices = self._choose_blocks()
|
|
|
+ self.module_container = ModuleContainer.create(
|
|
|
+ dht=self.dht,
|
|
|
+ prefix=self.prefix,
|
|
|
+ converted_model_name_or_path=self.converted_model_name_or_path,
|
|
|
+ block_config=self.block_config,
|
|
|
+ memory_cache=self.memory_cache,
|
|
|
+ throughput=self.throughput,
|
|
|
+ block_indices=block_indices,
|
|
|
+ num_handlers=self.num_handlers,
|
|
|
+ min_batch_size=self.min_batch_size,
|
|
|
+ max_batch_size=self.max_batch_size,
|
|
|
+ inference_max_length=self.inference_max_length,
|
|
|
+ torch_dtype=self.torch_dtype,
|
|
|
+ cache_dir=self.cache_dir,
|
|
|
+ device=self.device,
|
|
|
+ compression=self.compression,
|
|
|
+ stats_report_interval=self.stats_report_interval,
|
|
|
+ update_period=self.update_period,
|
|
|
+ expiration=self.expiration,
|
|
|
+ prefetch_batches=self.prefetch_batches,
|
|
|
+ sender_threads=self.sender_threads,
|
|
|
+ use_auth_token=self.use_auth_token,
|
|
|
+ load_in_8bit=self.load_in_8bit,
|
|
|
+ start=True,
|
|
|
+ )
|
|
|
+ try:
|
|
|
+ self.module_container.ready.wait()
|
|
|
+
|
|
|
+ while True:
|
|
|
+ timeout = random.random() * self.max_balance_check_period
|
|
|
+ if self.stop.wait(timeout):
|
|
|
+ return
|
|
|
+ if self._should_choose_other_blocks():
|
|
|
+ break # Stop serving this set of modules
|
|
|
+ finally:
|
|
|
+ self.module_container.shutdown()
|
|
|
+
|
|
|
+ def _choose_blocks(self) -> List[int]:
|
|
|
+ if self.block_indices is not None:
|
|
|
+ return self.block_indices
|
|
|
+
|
|
|
+ # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
|
|
|
+ # this delay decreases the probability of a race condition while choosing the best blocks to serve.
|
|
|
+ time.sleep(random.random() * self.max_block_selection_delay)
|
|
|
+
|
|
|
+ assert self.num_blocks is not None
|
|
|
+ uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
|
+ module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
|
|
|
+ return choose_best_blocks(self.num_blocks, module_infos)
|
|
|
+
|
|
|
+ def _should_choose_other_blocks(self) -> bool:
|
|
|
+ return False
|
|
|
|
|
|
def shutdown(self):
|
|
|
self.stop.set()
|