|
@@ -41,7 +41,6 @@ class Server(threading.Thread):
|
|
|
throughput: Union[float, str],
|
|
|
num_blocks: Optional[int] = None,
|
|
|
block_indices: Optional[str] = None,
|
|
|
- allow_rebalancing: bool = True,
|
|
|
num_handlers: int = 8,
|
|
|
min_batch_size: int = 1,
|
|
|
max_batch_size: int = 4096,
|
|
@@ -60,7 +59,7 @@ class Server(threading.Thread):
|
|
|
prefetch_batches: int = 1,
|
|
|
sender_threads: int = 1,
|
|
|
mean_block_selection_delay: float = 0.5,
|
|
|
- mean_balance_check_period: float = 300, # TODO:
|
|
|
+ mean_balance_check_period: float = 300,
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
load_in_8bit: bool = False,
|
|
|
*,
|
|
@@ -133,16 +132,17 @@ class Server(threading.Thread):
|
|
|
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
|
|
|
raise
|
|
|
block_indices = range(first_block_index, last_block_index)
|
|
|
- self.block_indices, self.num_blocks = block_indices, num_blocks
|
|
|
- self.allow_rebalancing = allow_rebalancing
|
|
|
+ self.strict_block_indices, self.num_blocks = block_indices, num_blocks
|
|
|
self.mean_block_selection_delay = mean_block_selection_delay
|
|
|
self.mean_balance_check_period = mean_balance_check_period
|
|
|
+ self._module_infos = None
|
|
|
|
|
|
self.stop = threading.Event()
|
|
|
if start:
|
|
|
self.start()
|
|
|
|
|
|
def run(self):
|
|
|
+ self._update_module_infos()
|
|
|
while True:
|
|
|
block_indices = self._choose_blocks()
|
|
|
self.module_container = ModuleContainer.create(
|
|
@@ -178,27 +178,34 @@ class Server(threading.Thread):
|
|
|
# TODO: Follow ModuleContainer status (to restart/stop if it crashes)
|
|
|
if self.stop.wait(timeout):
|
|
|
return
|
|
|
+
|
|
|
+ self._update_module_infos()
|
|
|
if self._should_choose_other_blocks():
|
|
|
logger.info("Network is imbalanced, server will load other blocks")
|
|
|
break # Stop serving this set of modules
|
|
|
finally:
|
|
|
self.module_container.shutdown()
|
|
|
|
|
|
- def _choose_blocks(self) -> List[int]:
|
|
|
- if self.block_indices is not None:
|
|
|
- return self.block_indices
|
|
|
+ def _update_module_infos(self) -> None:
|
|
|
+ if self.strict_block_indices:
|
|
|
+ return # No need for self._module_infos in this case
|
|
|
|
|
|
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
|
|
|
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
|
|
|
time.sleep(random.random() * 2 * self.mean_block_selection_delay)
|
|
|
|
|
|
- assert self.num_blocks is not None
|
|
|
uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
|
- module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
|
|
|
- return choose_best_blocks(self.num_blocks, module_infos)
|
|
|
+ self._module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
|
|
|
+
|
|
|
+ def _choose_blocks(self) -> List[int]:
|
|
|
+ if self.strict_block_indices:
|
|
|
+ return self.strict_block_indices
|
|
|
+
|
|
|
+ assert self.num_blocks is not None
|
|
|
+ return choose_best_blocks(self.num_blocks, self._module_infos)
|
|
|
|
|
|
def _should_choose_other_blocks(self) -> bool:
|
|
|
- if not self.allow_rebalancing:
|
|
|
+ if self.strict_block_indices:
|
|
|
return False
|
|
|
|
|
|
# TODO: Implement actual algorithm here
|