import threading
from typing import Dict, Optional, Sequence, Union

import torch
from hivemind import BatchTensorDescriptor, DHT, Server
from hivemind.moe.server.dht_handler import DHTHandlerThread
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src import DistributedBloomConfig
from src.bloom.block import BloomBlock
from src.server.backend import BloomBlockBackend
from src.server.cache import MemoryCache
from src.server.handler import BloomConnectionHandler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

class BloomServer(Server):
    """Serves one or more BLOOM layers for inference, forward and backward passes; announces itself to the DHT"""
    def __init__(
        self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
        device: torch.device, num_connection_handlers: int = 8, update_period: float = 30,
        expiration: Optional[float] = None, cache_size_bytes: Optional[int] = None, start: bool, **kwargs,
    ):
        threading.Thread.__init__(self)
        self.attention_cache = MemoryCache(device=device, max_size_bytes=cache_size_bytes)
        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
        self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
        # periodically announce the served module backends to the DHT; `expiration` is forwarded here on the
        # assumption that hivemind's DHTHandlerThread accepts it, as in hivemind's own Server
        self.dht_handler_thread = DHTHandlerThread(
            self.module_backends, dht, update_period=update_period, expiration=expiration, daemon=True
        )
        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state

        if start:
            self.run_in_background(await_ready=True)

    # noinspection PyMethodOverriding
    @classmethod
    def create(
        cls,
        num_blocks: int,
        block_config: str,
        num_handlers: Optional[int] = None,
        min_batch_size: int = 1,
        max_batch_size: int = 4096,
        cache_size_bytes: Optional[int] = None,
        device: Union[str, torch.device] = None,
        initial_peers: Sequence[str] = (),
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 30,
        expiration: Optional[float] = None,
        *,
        start: bool,
        **kwargs,
    ) -> Server:
- """Create a server with one or more bloom blocks. See run_server.py for documentation."""
- if custom_module_path is not None:
- add_custom_models_from_file(custom_module_path)
- dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
- visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
- logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
- num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
- device = device or ("cuda" if torch.cuda.is_available() else "cpu")
- if isinstance(block_config, str):
- block_config = DistributedBloomConfig

        # initialize modules
        module_backends = {}
        for i in range(num_blocks):
            module_uid = f"dummy_block.{i}"
            block = BloomBlock(block_config, layer_number=i)
            # TODO run the actual model: load pretrained weights instead of this randomly initialized block

            # the schema below is a sketch: it assumes hidden states of up to 2048 tokens of size hidden_size
            # and that BloomBlockBackend accepts the same keyword arguments as hivemind's ExpertBackend
            hidden_schema = BatchTensorDescriptor(2048, block_config.hidden_size, compression=compression)
            module_backends[module_uid] = BloomBlockBackend(
                name=module_uid,
                expert=block,
                args_schema=(hidden_schema,),
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )
            # the served blocks are frozen, so no optimizer, scheduler or gradient clipping is configured here

        # checkpoints are neither loaded nor saved: the served block weights are never modified (see checkpoint_saver)
        return cls(
            dht,
            module_backends,
            cache_size_bytes=cache_size_bytes,
            num_connection_handlers=num_handlers,
            device=device,
            stats_report_interval=stats_report_interval,
            update_period=update_period,
            expiration=expiration,
            start=start,
        )
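

# A minimal usage sketch: the config name and peer address below are placeholders, and run_server.py
# remains the supported command-line entry point.
#
#     server = BloomServer.create(
#         num_blocks=2,
#         block_config="bigscience/bloom",  # any BLOOM config name or local path
#         initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/Qm..."],  # an existing DHT peer, if any
#         start=True,
#     )
#     server.join()  # block the calling thread until the server shuts down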