server.py

import threading
from typing import Optional, Dict, Union, Sequence

import torch
from hivemind import Server, DHT, BatchTensorDescriptor
from hivemind.moe.server.dht_handler import DHTHandlerThread
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import use_hivemind_log_handler, get_logger

from src import DistributedBloomConfig
from src.bloom.block import BloomBlock
from src.server.cache import MemoryCache
from src.server.backend import BloomBlockBackend
from src.server.handler import BloomConnectionHandler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


class BloomServer(Server):
    """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""

    def __init__(
        self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
        device: torch.device, num_connection_handlers: int = 8, update_period: float = 30,
        expiration: Optional[float] = None, cache_size_bytes: Optional[int] = None, start: bool, **kwargs,
    ):
        threading.Thread.__init__(self)
        self.attention_cache = MemoryCache(device=device, max_size_bytes=cache_size_bytes)
        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
        self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
        self.dht_handler_thread = DHTHandlerThread(
            self.module_backends, dht, update_period=update_period, expiration=expiration, daemon=True
        )
        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state

        if start:
            self.run_in_background(await_ready=True)

    # noinspection PyMethodOverriding
    @classmethod
    def create(
        cls,
        num_blocks: int,
        block_config: Union[str, DistributedBloomConfig],
        num_handlers: Optional[int] = None,
        min_batch_size: int = 1,
        max_batch_size: int = 4096,
        cache_size_bytes: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        initial_peers: Sequence[str] = (),
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 30,
        expiration: Optional[float] = None,
        *,
        start: bool,
        **kwargs,
    ) -> Server:
  54. """Create a server with one or more bloom blocks. See run_server.py for documentation."""
  55. if custom_module_path is not None:
  56. add_custom_models_from_file(custom_module_path)
  57. dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
  58. visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
  59. logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
  60. num_handlers = num_handlers if num_handlers is not None else num_blocks * 8
  61. device = device or ("cuda" if torch.cuda.is_available() else "cpu")
  62. if isinstance(block_config, str):
  63. block_config = DistributedBloomConfig
        # initialize modules
        module_backends = {}
        for i in range(num_blocks):
            module_uid = f"dummy_block.{i}"
            block = BloomBlock(block_config, layer_number=i)
            # TODO run the actual model
            module_backends[module_uid] = BloomBlockBackend(
                name=module_uid,
                expert=block,
                # placeholder input schema (seq_len x hidden_size, batch dim is implicit); adjust to the real limits
                args_schema=(BatchTensorDescriptor(2048, block_config.hidden_size, compression=compression),),
                # optimizer/scheduler arguments are omitted: this server only runs inference
                min_batch_size=min_batch_size,
                max_batch_size=max_batch_size,
            )
        # TODO: load pretrained block weights here once the real model is wired up (see TODO above)
        return cls(
            dht,
            module_backends,
            cache_size_bytes=cache_size_bytes,
            num_connection_handlers=num_handlers,
            device=device,
            stats_report_interval=stats_report_interval,
            update_period=update_period,
            expiration=expiration,
            start=start,
        )
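
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not taken from run_server.py): one way the
# class above could be launched and shut down. The block count, model name and
# cache size below are placeholder values, not settings from this repository.
if __name__ == "__main__":
    server = BloomServer.create(
        num_blocks=2,  # placeholder: serve two bloom blocks
        block_config="bigscience/bloom-6b3",  # placeholder model name or config path
        cache_size_bytes=1 << 30,  # 1 GiB attention cache
        device="cuda" if torch.cuda.is_available() else "cpu",
        initial_peers=(),  # start a new swarm; pass multiaddrs here to join an existing one
        start=True,  # spawn the DHT handler, connection handlers and runtime in the background
    )
    try:
        server.join()  # serve blocks until interrupted
    except KeyboardInterrupt:
        server.shutdown()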