|
@@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
-class Server(threading.Thread):
|
|
|
+class Server:
|
|
|
"""
|
|
|
Runs ModuleContainer, periodically checks that the network is balanced,
|
|
|
restarts the ModuleContainer with other layers if the imbalance is significant
|
|
@@ -68,13 +68,10 @@ class Server(threading.Thread):
|
|
|
mean_block_selection_delay: float = 0.5,
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
load_in_8bit: bool = False,
|
|
|
- start: bool,
|
|
|
**kwargs,
|
|
|
):
|
|
|
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
|
|
|
|
|
|
- super().__init__()
|
|
|
-
|
|
|
self.converted_model_name_or_path = converted_model_name_or_path
|
|
|
self.num_handlers = num_handlers
|
|
|
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
|
|
@@ -147,8 +144,6 @@ class Server(threading.Thread):
|
|
|
self.mean_block_selection_delay = mean_block_selection_delay
|
|
|
|
|
|
self.stop = threading.Event()
|
|
|
- if start:
|
|
|
- self.start()
|
|
|
|
|
|
def run(self):
|
|
|
while True:
|
|
@@ -231,6 +226,118 @@ class Server(threading.Thread):
|
|
|
class ModuleContainer(threading.Thread):
|
|
|
"""Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
|
|
|
|
|
|
+ # noinspection PyMethodOverriding
|
|
|
+ @classmethod
|
|
|
+ def create(
|
|
|
+ cls,
|
|
|
+ *,
|
|
|
+ dht: DHT,
|
|
|
+ prefix: str,
|
|
|
+ converted_model_name_or_path: str,
|
|
|
+ block_config: BloomConfig,
|
|
|
+ memory_cache: MemoryCache,
|
|
|
+ throughput: float,
|
|
|
+ block_indices: List[int],
|
|
|
+ num_handlers: Optional[int],
|
|
|
+ min_batch_size: int,
|
|
|
+ max_batch_size: int,
|
|
|
+ inference_max_length: int,
|
|
|
+ torch_dtype: torch.dtype,
|
|
|
+ cache_dir: Optional[str],
|
|
|
+ device: Union[str, torch.device],
|
|
|
+ compression: CompressionType,
|
|
|
+ stats_report_interval: Optional[int],
|
|
|
+ update_period: float,
|
|
|
+ expiration: Optional[float],
|
|
|
+ prefetch_batches: int,
|
|
|
+ sender_threads: int,
|
|
|
+ use_auth_token: Optional[str],
|
|
|
+ load_in_8bit: bool,
|
|
|
+ start: bool,
|
|
|
+ ) -> ModuleContainer:
|
|
|
+ module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
|
|
|
+ joining_announcer = ModuleAnnouncerThread(
|
|
|
+ module_uids,
|
|
|
+ dht,
|
|
|
+ ServerState.JOINING,
|
|
|
+ throughput=throughput,
|
|
|
+ update_period=update_period,
|
|
|
+ expiration=expiration,
|
|
|
+ daemon=True,
|
|
|
+ )
|
|
|
+ joining_announcer.start()
|
|
|
+ logger.info(f"Announced that blocks {block_indices} are joining")
|
|
|
+
|
|
|
+ try:
|
|
|
+ blocks = {}
|
|
|
+ for module_uid, block_index in zip(module_uids, block_indices):
|
|
|
+ block = load_pretrained_block(
|
|
|
+ converted_model_name_or_path,
|
|
|
+ block_index,
|
|
|
+ block_config,
|
|
|
+ torch_dtype=torch_dtype,
|
|
|
+ use_auth_token=use_auth_token,
|
|
|
+ cache_dir=cache_dir,
|
|
|
+ )
|
|
|
+
|
|
|
+ if load_in_8bit:
|
|
|
+ dtype = block.input_layernorm.weight.dtype
|
|
|
+ block = replace_8bit_linear(block)
|
|
|
+
|
|
|
+ block = block.to(device)
|
|
|
+ for param in block.parameters():
|
|
|
+ param.requires_grad = False
|
|
|
+
|
|
|
+ blocks[module_uid] = TransformerBackend(
|
|
|
+ module_uid,
|
|
|
+ block,
|
|
|
+ memory_cache=memory_cache,
|
|
|
+ backend_dtype=None if torch_dtype == "auto" else torch_dtype,
|
|
|
+ args_schema=(
|
|
|
+ BatchTensorDescriptor(
|
|
|
+ 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ kwargs_schema={},
|
|
|
+ outputs_schema=(
|
|
|
+ BatchTensorDescriptor(
|
|
|
+ 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
|
|
+ ),
|
|
|
+ ),
|
|
|
+ min_batch_size=min_batch_size,
|
|
|
+ max_batch_size=max_batch_size,
|
|
|
+ )
|
|
|
+ except:
|
|
|
+ joining_announcer.stop.set()
|
|
|
+ joining_announcer.join()
|
|
|
+ declare_active_modules(
|
|
|
+ dht,
|
|
|
+ module_uids,
|
|
|
+ expiration_time=get_dht_time() + expiration,
|
|
|
+ state=ServerState.OFFLINE,
|
|
|
+ throughput=throughput,
|
|
|
+ )
|
|
|
+ logger.info(f"Announced that blocks {module_uids} are offline")
|
|
|
+ raise
|
|
|
+ else:
|
|
|
+ joining_announcer.stop.set()
|
|
|
+ joining_announcer.join()
|
|
|
+
|
|
|
+ return cls(
|
|
|
+ dht,
|
|
|
+ blocks,
|
|
|
+ throughput=throughput,
|
|
|
+ num_connection_handlers=num_handlers,
|
|
|
+ inference_max_length=inference_max_length,
|
|
|
+ device=device,
|
|
|
+ stats_report_interval=stats_report_interval,
|
|
|
+ update_period=update_period,
|
|
|
+ expiration=expiration,
|
|
|
+ prefetch_batches=prefetch_batches,
|
|
|
+ sender_threads=sender_threads,
|
|
|
+ start=start,
|
|
|
+ )
|
|
|
+
|
|
|
def __init__(
|
|
|
self,
|
|
|
dht: DHT,
|
|
@@ -253,9 +360,10 @@ class ModuleContainer(threading.Thread):
|
|
|
for _ in range(num_connection_handlers)
|
|
|
]
|
|
|
self.runtime = Runtime(self.module_backends, **kwargs)
|
|
|
- self.dht_handler_thread = ModuleAnnouncerThread(
|
|
|
- self.module_backends,
|
|
|
+ self.online_announcer = ModuleAnnouncerThread(
|
|
|
+ list(self.module_backends.keys()),
|
|
|
dht,
|
|
|
+ ServerState.ONLINE,
|
|
|
throughput=throughput,
|
|
|
update_period=update_period,
|
|
|
expiration=expiration,
|
|
@@ -279,8 +387,7 @@ class ModuleContainer(threading.Thread):
|
|
|
if not self.dht.is_alive():
|
|
|
self.dht.run_in_background(await_ready=True)
|
|
|
|
|
|
- if self.module_backends:
|
|
|
- self.dht_handler_thread.start()
|
|
|
+ self.online_announcer.start()
|
|
|
|
|
|
if self.checkpoint_saver is not None:
|
|
|
self.checkpoint_saver.start()
|
|
@@ -290,99 +397,6 @@ class ModuleContainer(threading.Thread):
|
|
|
|
|
|
self.runtime.run()
|
|
|
|
|
|
- # noinspection PyMethodOverriding
|
|
|
- @classmethod
|
|
|
- def create(
|
|
|
- cls,
|
|
|
- *,
|
|
|
- dht: DHT,
|
|
|
- prefix: str,
|
|
|
- converted_model_name_or_path: str,
|
|
|
- block_config: BloomConfig,
|
|
|
- memory_cache: MemoryCache,
|
|
|
- throughput: float,
|
|
|
- block_indices: List[int],
|
|
|
- num_handlers: Optional[int],
|
|
|
- min_batch_size: int,
|
|
|
- max_batch_size: int,
|
|
|
- inference_max_length: int,
|
|
|
- torch_dtype: torch.dtype,
|
|
|
- cache_dir: Optional[str],
|
|
|
- device: Union[str, torch.device],
|
|
|
- compression: CompressionType,
|
|
|
- stats_report_interval: Optional[int],
|
|
|
- update_period: float,
|
|
|
- expiration: Optional[float],
|
|
|
- prefetch_batches: int,
|
|
|
- sender_threads: int,
|
|
|
- use_auth_token: Optional[str],
|
|
|
- load_in_8bit: bool,
|
|
|
- start: bool,
|
|
|
- ) -> ModuleContainer:
|
|
|
- module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
|
|
|
- declare_active_modules(
|
|
|
- dht,
|
|
|
- module_uids,
|
|
|
- expiration_time=get_dht_time() + expiration,
|
|
|
- state=ServerState.JOINING,
|
|
|
- throughput=throughput,
|
|
|
- )
|
|
|
- logger.info(f"Announced that blocks {block_indices} are joining")
|
|
|
-
|
|
|
- blocks = {}
|
|
|
- for module_uid, block_index in zip(module_uids, block_indices):
|
|
|
- block = load_pretrained_block(
|
|
|
- converted_model_name_or_path,
|
|
|
- block_index,
|
|
|
- block_config,
|
|
|
- torch_dtype=torch_dtype,
|
|
|
- use_auth_token=use_auth_token,
|
|
|
- cache_dir=cache_dir,
|
|
|
- )
|
|
|
-
|
|
|
- if load_in_8bit:
|
|
|
- dtype = block.input_layernorm.weight.dtype
|
|
|
- block = replace_8bit_linear(block)
|
|
|
-
|
|
|
- block = block.to(device)
|
|
|
- for param in block.parameters():
|
|
|
- param.requires_grad = False
|
|
|
-
|
|
|
- blocks[module_uid] = TransformerBackend(
|
|
|
- module_uid,
|
|
|
- block,
|
|
|
- memory_cache=memory_cache,
|
|
|
- backend_dtype=None if torch_dtype == "auto" else torch_dtype,
|
|
|
- args_schema=(
|
|
|
- BatchTensorDescriptor(
|
|
|
- 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
|
|
- ),
|
|
|
- ),
|
|
|
- kwargs_schema={},
|
|
|
- outputs_schema=(
|
|
|
- BatchTensorDescriptor(
|
|
|
- 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
|
|
- ),
|
|
|
- ),
|
|
|
- min_batch_size=min_batch_size,
|
|
|
- max_batch_size=max_batch_size,
|
|
|
- )
|
|
|
-
|
|
|
- return cls(
|
|
|
- dht,
|
|
|
- blocks,
|
|
|
- throughput=throughput,
|
|
|
- num_connection_handlers=num_handlers,
|
|
|
- inference_max_length=inference_max_length,
|
|
|
- device=device,
|
|
|
- stats_report_interval=stats_report_interval,
|
|
|
- update_period=update_period,
|
|
|
- expiration=expiration,
|
|
|
- prefetch_batches=prefetch_batches,
|
|
|
- sender_threads=sender_threads,
|
|
|
- start=start,
|
|
|
- )
|
|
|
-
|
|
|
def run_in_background(self, await_ready=True, timeout=None):
|
|
|
"""
|
|
|
Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
|
|
@@ -411,18 +425,17 @@ class ModuleContainer(threading.Thread):
|
|
|
Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
|
|
|
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
|
|
|
"""
|
|
|
- if self.module_backends:
|
|
|
- self.dht_handler_thread.stop.set()
|
|
|
- self.dht_handler_thread.join()
|
|
|
+ self.online_announcer.stop.set()
|
|
|
+ self.online_announcer.join()
|
|
|
|
|
|
- declare_active_modules(
|
|
|
- self.dht,
|
|
|
- self.module_backends.keys(),
|
|
|
- expiration_time=get_dht_time() + self.expiration,
|
|
|
- state=ServerState.OFFLINE,
|
|
|
- throughput=self.throughput,
|
|
|
- )
|
|
|
- logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
|
|
|
+ declare_active_modules(
|
|
|
+ self.dht,
|
|
|
+ self.module_backends.keys(),
|
|
|
+ expiration_time=get_dht_time() + self.expiration,
|
|
|
+ state=ServerState.OFFLINE,
|
|
|
+ throughput=self.throughput,
|
|
|
+ )
|
|
|
+ logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
|
|
|
|
|
|
self.ready.clear()
|
|
|
|
|
@@ -450,8 +463,9 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- module_backends: Dict[str, TransformerBackend],
|
|
|
+ module_uids: List[str],
|
|
|
dht: DHT,
|
|
|
+ state: ServerState,
|
|
|
*,
|
|
|
throughput: float,
|
|
|
update_period: float = 30,
|
|
@@ -459,8 +473,9 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
**kwargs,
|
|
|
):
|
|
|
super().__init__(**kwargs)
|
|
|
- self.module_backends = module_backends
|
|
|
+ self.module_uids = module_uids
|
|
|
self.dht = dht
|
|
|
+ self.state = state
|
|
|
self.throughput = throughput
|
|
|
self.update_period = update_period
|
|
|
self.expiration = expiration
|
|
@@ -470,9 +485,9 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
while True:
|
|
|
declare_active_modules(
|
|
|
self.dht,
|
|
|
- self.module_backends.keys(),
|
|
|
+ self.module_uids,
|
|
|
expiration_time=get_dht_time() + self.expiration,
|
|
|
- state=ServerState.ONLINE,
|
|
|
+ state=self.state,
|
|
|
throughput=self.throughput,
|
|
|
)
|
|
|
if self.stop.wait(self.update_period):
|