@@ -24,6 +24,7 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend
+from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.throughput import get_host_throughput
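Note: `get_block_size` is the only new import; the block-count hunk below calls it as `get_block_size(config, "memory", dtype=..., load_in_8bit=...)` to get a per-block weight footprint in bytes. A minimal sketch of what such an estimate boils down to (an assumption for illustration, not the real `block_utils` code; `n_params_per_block` is a hypothetical input):

    import torch

    def approx_block_size(n_params_per_block: int, dtype: torch.dtype, load_in_8bit: bool) -> int:
        # int8 weights take 1 byte per parameter; otherwise use the dtype's width.
        bytes_per_param = 1 if load_in_8bit else torch.finfo(dtype).bits // 8
        return n_params_per_block * bytes_per_param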
@@ -125,6 +126,11 @@ class Server:
         device = torch.device(device)
         self.device = device
 
+        if isinstance(torch_dtype, str):
+            torch_dtype = DTYPE_MAP[torch_dtype]
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        self.torch_dtype = torch_dtype
+
         if load_in_8bit is None:
             load_in_8bit = device.type == "cuda"
         if load_in_8bit:
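Note: this block is moved up from below (see the next hunk), so `self.torch_dtype` is already set by the time the memory-cache and block-count logic needs it. `DTYPE_MAP` translates the CLI string into a torch dtype; its exact contents live elsewhere in the Petals codebase, but it plausibly has this shape (an assumed sketch):

    import torch

    DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
    assert DTYPE_MAP["float16"] is torch.float16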
@@ -152,11 +158,6 @@ class Server:
         logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
         self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
 
-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        self.torch_dtype = torch_dtype
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_host_throughput(
@@ -181,19 +182,19 @@ class Server:
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
         assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
 
+        total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
         gib = 1024**3
-        total_memory_gib = torch.cuda.get_device_properties(self.device).total_memory / gib
-        block_size_gib = 176 / 70 + 0.5
-        if not self.load_in_8bit:
-            block_size_gib *= 2 if self.dtype in (torch.float16, torch.bfloat16) else 4
-        num_blocks = math.floor((total_memory_gib - 2) / block_size_gib)
+        attn_cache_per_block = 0.5 * gib  # TODO: This does not account for manually set --attn_cache_size
+
+        num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
         logger.info(
             f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
             f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
         )
-        return num_blocks
+        return min(num_blocks, self.block_config.n_layer)
 
     def run(self):
         while True:
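Note: the replaced heuristic was both model-specific and fragile: `176 / 70 + 0.5` hard-codes BLOOM-176B's 70 blocks, the dtype multiplier scales the 0.5 GiB cache allowance along with the weights, and the removed branch reads `self.dtype` while `__init__` (first hunk above) stores `self.torch_dtype`. The new code measures the block's byte size directly, budgets the attention cache separately, keeps 2 GiB of headroom, and caps the result at `n_layer` so small models cannot claim more blocks than they have. A worked example with assumed numbers (a 24 GiB GPU serving a BLOOM-176B-like model in 8-bit):

    import math

    gib = 1024**3
    total_memory = 24 * gib            # assumed GPU capacity
    block_size = int(176e9 / 70)       # ~2.34 GiB of int8 weights per block (assumed)
    attn_cache_per_block = 0.5 * gib   # default per-block cache budget from the hunk above
    num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
    assert num_blocks == 7             # (24 - 2) GiB / ~2.84 GiB per block

The `- 2 * gib` term is the reserve left for the CUDA context and activations.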
@@ -231,10 +232,13 @@ class Server:
 
                 while True:
                     timeout = random.random() * 2 * self.mean_balance_check_period
-                    # TODO: Follow ModuleContainer status (to restart/stop if it crashes)
                     if self.stop.wait(timeout):
                         return
 
+                    if not self.module_container.handlers_alive:
+                        logger.warning("One of connection handlers crashed, restarting the server")
+                        break
+
                     if self._should_choose_other_blocks():
                         logger.info("Swarm is imbalanced, server will load other blocks")
                         break  # Stop serving this set of modules
@@ -466,6 +470,10 @@ class ModuleContainer(threading.Thread):
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
+    @property
+    def handlers_alive(self) -> bool:
+        return all(handler.is_alive() for handler in self.conn_handlers)
+
     def shutdown(self):
         """
         Gracefully terminate the container, process-safe.
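Note: together, the last two hunks turn the stale TODO into an actual health check: the balance loop in `Server.run` now polls `handlers_alive` on each wake-up and breaks out to rebuild the `ModuleContainer` if any connection handler died. `is_alive()` exists on both `threading.Thread` and `multiprocessing.Process`, so the property works regardless of how `conn_handlers` are run, and it is cheap enough to call once per balance interval. A minimal standalone sketch of the same watchdog pattern (names assumed, not Petals code):

    import random
    import threading

    def watch(workers: list[threading.Thread], stop: threading.Event, mean_period: float) -> bool:
        """Return True on a requested stop, False if any worker died."""
        while True:
            # Uniform sleep in [0, 2 * mean_period), so the mean is mean_period.
            if stop.wait(random.random() * 2 * mean_period):
                return True
            if not all(w.is_alive() for w in workers):
                return False  # caller tears everything down and restarts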