|
@@ -6,6 +6,7 @@ import threading
|
|
|
import time
|
|
|
from typing import Dict, List, Optional, Sequence, Union
|
|
|
|
|
|
+import numpy as np
|
|
|
import torch
|
|
|
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
|
|
|
from hivemind.moe.server.layers import add_custom_models_from_file
|
|
@@ -17,8 +18,8 @@ from src import BloomConfig, declare_active_modules
|
|
|
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
|
|
|
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
|
|
|
from src.dht_utils import get_remote_module_infos
|
|
|
+from src.server import block_selection
|
|
|
from src.server.backend import TransformerBackend
|
|
|
-from src.server.block_selection import choose_best_blocks
|
|
|
from src.server.cache import MemoryCache
|
|
|
from src.server.handler import TransformerConnectionHandler
|
|
|
from src.server.throughput import get_host_throughput
|
|
@@ -59,7 +60,8 @@ class Server(threading.Thread):
|
|
|
prefetch_batches: int = 1,
|
|
|
sender_threads: int = 1,
|
|
|
mean_block_selection_delay: float = 0.5,
|
|
|
- mean_balance_check_period: float = 300,
|
|
|
+ mean_balance_check_period: float = 150,
|
|
|
+ min_balance_quality: float = 0.8,
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
load_in_8bit: bool = False,
|
|
|
*,
|
|
@@ -122,6 +124,7 @@ class Server(threading.Thread):
|
|
|
use_auth_token=use_auth_token,
|
|
|
revision=revision,
|
|
|
)
|
|
|
+ self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
|
|
|
|
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
|
|
|
if block_indices is not None:
|
|
@@ -135,14 +138,13 @@ class Server(threading.Thread):
|
|
|
self.strict_block_indices, self.num_blocks = block_indices, num_blocks
|
|
|
self.mean_block_selection_delay = mean_block_selection_delay
|
|
|
self.mean_balance_check_period = mean_balance_check_period
|
|
|
- self._module_infos = None
|
|
|
+ self.min_balance_quality = min_balance_quality
|
|
|
|
|
|
self.stop = threading.Event()
|
|
|
if start:
|
|
|
self.start()
|
|
|
|
|
|
def run(self):
|
|
|
- self._update_module_infos()
|
|
|
while True:
|
|
|
block_indices = self._choose_blocks()
|
|
|
self.module_container = ModuleContainer.create(
|
|
@@ -179,37 +181,29 @@ class Server(threading.Thread):
|
|
|
if self.stop.wait(timeout):
|
|
|
return
|
|
|
|
|
|
- self._update_module_infos()
|
|
|
if self._should_choose_other_blocks():
|
|
|
- logger.info("Network is imbalanced, server will load other blocks")
|
|
|
+ logger.info("Swarm is imbalanced, server will load other blocks")
|
|
|
break # Stop serving this set of modules
|
|
|
finally:
|
|
|
self.module_container.shutdown()
|
|
|
|
|
|
- def _update_module_infos(self) -> None:
|
|
|
- if self.strict_block_indices:
|
|
|
- return # No need for self._module_infos in this case
|
|
|
+ def _choose_blocks(self) -> List[int]:
|
|
|
+ if self.strict_block_indices is not None:
|
|
|
+ return self.strict_block_indices
|
|
|
+ assert self.num_blocks is not None
|
|
|
|
|
|
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
|
|
|
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
|
|
|
time.sleep(random.random() * 2 * self.mean_block_selection_delay)
|
|
|
-
|
|
|
- uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
|
- self._module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
|
|
|
-
|
|
|
- def _choose_blocks(self) -> List[int]:
|
|
|
- if self.strict_block_indices:
|
|
|
- return self.strict_block_indices
|
|
|
-
|
|
|
- assert self.num_blocks is not None
|
|
|
- return choose_best_blocks(self.num_blocks, self._module_infos)
|
|
|
+ module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
|
|
|
+ return block_selection.choose_best_blocks(self.num_blocks, module_infos)
|
|
|
|
|
|
def _should_choose_other_blocks(self) -> bool:
|
|
|
- if self.strict_block_indices:
|
|
|
+ if self.strict_block_indices is not None:
|
|
|
return False
|
|
|
|
|
|
- # TODO: Implement actual algorithm here
|
|
|
- return True
|
|
|
+ module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
|
|
|
+ return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.min_balance_quality)
|
|
|
|
|
|
def shutdown(self):
|
|
|
self.stop.set()
|