@@ -9,7 +9,6 @@ import time
 from typing import Dict, List, Optional, Sequence, Union

 import numpy as np
-import psutil
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -28,7 +27,7 @@ from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import check_reachability
-from petals.server.throughput import get_host_throughput
+from petals.server.throughput import get_dtype_name, get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR

@@ -146,12 +145,6 @@ class Server:
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
         self.torch_dtype = torch_dtype

-        if load_in_8bit is None:
-            load_in_8bit = device.type == "cuda"
-        if load_in_8bit:
-            logger.info("Model weights will be loaded in 8-bit format")
-        self.load_in_8bit = load_in_8bit
-
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
         self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
@@ -159,6 +152,17 @@ class Server:
             logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
             check_device_balance(self.tensor_parallel_devices)

+        if load_in_8bit is None:
+            load_in_8bit = device.type == "cuda"
+            if load_in_8bit and len(self.tensor_parallel_devices) > 1:
+                load_in_8bit = False
+                logger.warning(
+                    "Tensor parallelism doesn't work properly with 8-bit weights yet, loading weights in 16-bit. "
+                    "You can explicitly set `--load_in_8bit True` to override this"
+                )
+        self.load_in_8bit = load_in_8bit
+        logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
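
The default chosen above can be traced in isolation. The sketch below mirrors the added logic; get_dtype_name and resolve_load_in_8bit are hypothetical stand-ins (the real get_dtype_name lives in petals.server.throughput and its implementation is not shown in this diff):

import torch

def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
    # Hypothetical stand-in; the real helper lives in petals.server.throughput
    return "8-bit" if load_in_8bit else str(dtype)

def resolve_load_in_8bit(device: torch.device, tensor_parallel_devices, load_in_8bit=None) -> bool:
    if load_in_8bit is None:
        load_in_8bit = device.type == "cuda"  # default to 8-bit on CUDA...
        if load_in_8bit and len(tensor_parallel_devices) > 1:
            load_in_8bit = False  # ...but fall back to 16-bit under tensor parallelism
    return load_in_8bit  # an explicit True/False is respected as-is

# Two GPUs: the default falls back to 16-bit, while an explicit True overrides it
assert resolve_load_in_8bit(torch.device("cuda"), ("cuda:0", "cuda:1")) is False
assert resolve_load_in_8bit(torch.device("cuda"), ("cuda:0", "cuda:1"), load_in_8bit=True) is True
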
@@ -167,8 +171,7 @@ class Server:
                 first_block_index, last_block_index = block_indices.split(":")
                 first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
             except Exception as e:
-                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
-                raise
+                raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
             block_indices = range(first_block_index, last_block_index)
             num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
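
The reworked error path surfaces the offending value in a ValueError instead of logging and re-raising the original exception. A self-contained sketch of the same parsing, assuming the start:end convention shown above (parse_block_indices is an illustrative name):

def parse_block_indices(block_indices: str) -> range:
    try:
        first_block_index, last_block_index = block_indices.split(":")
        first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
    except Exception:
        raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
    return range(first_block_index, last_block_index)

assert len(parse_block_indices("0:18")) == 18  # the end index is exclusive, like Python ranges
assert list(parse_block_indices(" 2 : 5 ")) == [2, 3, 4]  # whitespace around the indices is stripped
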
@@ -301,10 +304,6 @@ class Server:
         del self.module_container
         gc.collect()  # In particular, this closes unused file descriptors

-        cur_proc = psutil.Process()
-        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
-        logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
-
         if self.device.type == "cuda":
             torch.cuda.empty_cache()
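
With the psutil-based file descriptor report gone, shutdown reduces to the pattern kept above: drop the last references, collect garbage, then release cached CUDA memory. A minimal sketch of that ordering (the names here are illustrative):

import gc
import torch

def free_block_resources(module_container, device: torch.device) -> None:
    del module_container  # drop the last strong reference to the loaded blocks
    gc.collect()  # break reference cycles; this also closes unused file descriptors
    if device.type == "cuda":
        torch.cuda.empty_cache()  # return the freed blocks to the CUDA driver

The order matters: torch.cuda.empty_cache() can only release allocator blocks whose tensors have actually been freed, so the collection must happen first.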