Browse Source

Improve block size calculations (#149)

Alexander Borzunov 2 years ago
parent
commit
83d9493b6c
3 changed files with 70 additions and 17 deletions
  1. 48 0
      src/petals/server/block_utils.py
  2. 20 12
      src/petals/server/server.py
  3. 2 5
      src/petals/server/throughput.py

+ 48 - 0
src/petals/server/block_utils.py

@@ -0,0 +1,48 @@
+from typing import Optional, Union
+
+import torch
+from accelerate import init_empty_weights
+
+from petals.bloom import BloomBlock, BloomConfig
+
+
+def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
+    """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
+
+    if dtype == "auto" or dtype is None:
+        dtype = config.torch_dtype
+        if dtype == "auto" or dtype is None:
+            dtype = torch.float32
+    return dtype
+
+
+def get_block_size(
+    config: BloomConfig,
+    location: str,
+    *,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    load_in_8bit: Optional[bool] = None,
+    layer_index: int = 0,
+    eps: float = 0.01,  # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
+) -> int:
+    if location == "memory":
+        assert (
+            dtype is not None and load_in_8bit is not None
+        ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
+
+    with init_empty_weights():
+        block = BloomBlock(config, layer_index)
+        n_params = sum(param.numel() for param in block.parameters())
+
+    if location == "memory" and load_in_8bit:
+        # Note: We may need a larger eps here for models of size < 1B
+        return n_params * (1 + eps)
+
+    if location == "memory":
+        dtype = resolve_block_dtype(config, dtype)
+    elif location == "disk":
+        dtype = resolve_block_dtype(config, "auto")
+    else:
+        raise ValueError('get_block_size() expects location to be "memory" or "disk"')
+
+    return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps))

+ 20 - 12
src/petals/server/server.py

@@ -24,6 +24,7 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend
+from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.throughput import get_host_throughput
@@ -125,6 +126,11 @@ class Server:
         device = torch.device(device)
         self.device = device
 
+        if isinstance(torch_dtype, str):
+            torch_dtype = DTYPE_MAP[torch_dtype]
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        self.torch_dtype = torch_dtype
+
         if load_in_8bit is None:
             load_in_8bit = device.type == "cuda"
         if load_in_8bit:
@@ -152,11 +158,6 @@ class Server:
         logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
         self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
 
-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        self.torch_dtype = torch_dtype
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_host_throughput(
@@ -181,19 +182,19 @@ class Server:
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
         assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
 
+        total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
         gib = 1024**3
-        total_memory_gib = torch.cuda.get_device_properties(self.device).total_memory / gib
-        block_size_gib = 176 / 70 + 0.5
-        if not self.load_in_8bit:
-            block_size_gib *= 2 if self.dtype in (torch.float16, torch.bfloat16) else 4
-        num_blocks = math.floor((total_memory_gib - 2) / block_size_gib)
+        attn_cache_per_block = 0.5 * gib  # TODO: This does not account for manually set --attn_cache_size
+
+        num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
         logger.info(
             f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
             f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
         )
-        return num_blocks
+        return min(num_blocks, self.block_config.n_layer)
 
     def run(self):
         while True:
@@ -231,10 +232,13 @@ class Server:
 
                 while True:
                     timeout = random.random() * 2 * self.mean_balance_check_period
-                    # TODO: Follow ModuleContainer status (to restart/stop if it crashes)
                     if self.stop.wait(timeout):
                         return
 
+                    if not self.module_container.handlers_alive:
+                        logger.warning("One of connection handlers crashed, restarting the server")
+                        break
+
                     if self._should_choose_other_blocks():
                         logger.info("Swarm is imbalanced, server will load other blocks")
                         break  # Stop serving this set of modules
@@ -466,6 +470,10 @@ class ModuleContainer(threading.Thread):
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
+    @property
+    def handlers_alive(self) -> bool:
+        return all(handler.is_alive() for handler in self.conn_handlers)
+
     def shutdown(self):
         """
         Gracefully terminate the container, process-safe.

+ 2 - 5
src/petals/server/throughput.py

@@ -13,6 +13,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from petals.bloom.block import BloomBlock
 from petals.bloom.model import BloomConfig
 from petals.bloom.ops import build_alibi_tensor
+from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_8bit import replace_8bit_linear
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
@@ -29,11 +30,7 @@ def get_host_throughput(
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
 ) -> float:
-    # Resolve default dtypes
-    if dtype == "auto" or dtype is None:
-        dtype = config.torch_dtype
-        if dtype == "auto" or dtype is None:
-            dtype = torch.float32
+    dtype = resolve_block_dtype(config, dtype)
 
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR