|
@@ -22,6 +22,7 @@ from src.server.block_selection import choose_best_blocks
|
|
|
from src.server.cache import MemoryCache
|
|
|
from src.server.handler import TransformerConnectionHandler
|
|
|
from src.server.throughput import get_host_throughput
|
|
|
+from src.utils.convert_8bit import replace_8bit_linear
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
@@ -35,7 +36,6 @@ class Server(threading.Thread):
|
|
|
dht: DHT,
|
|
|
module_backends: Dict[str, TransformerBackend],
|
|
|
*,
|
|
|
- device: torch.device,
|
|
|
num_connection_handlers: int = 8,
|
|
|
throughput: float,
|
|
|
update_period: float = 30,
|
|
@@ -49,7 +49,7 @@ class Server(threading.Thread):
|
|
|
self.conn_handlers = [
|
|
|
TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
|
|
|
]
|
|
|
- self.runtime = Runtime(self.module_backends, device=device, **kwargs)
|
|
|
+ self.runtime = Runtime(self.module_backends, **kwargs)
|
|
|
self.dht_handler_thread = ModuleAnnouncerThread(
|
|
|
self.module_backends,
|
|
|
dht,
|
|
@@ -101,10 +101,12 @@ class Server(threading.Thread):
|
|
|
throughput: Union[float, str],
|
|
|
num_blocks: Optional[int] = None,
|
|
|
block_indices: Optional[str] = None,
|
|
|
- num_handlers: Optional[int] = None,
|
|
|
+ num_handlers: int = 8,
|
|
|
min_batch_size: int = 1,
|
|
|
max_batch_size: int = 4096,
|
|
|
torch_dtype: str = "auto",
|
|
|
+ revision: str = "main",
|
|
|
+ cache_dir: Optional[str] = None,
|
|
|
cache_size_bytes: Optional[int] = None,
|
|
|
device: Optional[Union[str, torch.device]] = None,
|
|
|
initial_peers: Sequence[str] = (),
|
|
@@ -115,6 +117,7 @@ class Server(threading.Thread):
|
|
|
expiration: Optional[float] = None,
|
|
|
max_block_selection_delay: float = 1,
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
+ load_in_8bit: bool = False,
|
|
|
*,
|
|
|
start: bool,
|
|
|
**kwargs,
|
|
@@ -148,7 +151,9 @@ class Server(threading.Thread):
|
|
|
torch_dtype = DTYPE_MAP[torch_dtype]
|
|
|
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
|
|
|
|
|
- block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
|
|
|
+ block_config = BloomConfig.from_pretrained(
|
|
|
+ converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
|
|
|
+ )
|
|
|
|
|
|
if block_indices is not None:
|
|
|
try:
|
|
@@ -186,7 +191,15 @@ class Server(threading.Thread):
|
|
|
block_config,
|
|
|
torch_dtype=torch_dtype,
|
|
|
use_auth_token=use_auth_token,
|
|
|
+ cache_dir=cache_dir,
|
|
|
)
|
|
|
+
|
|
|
+ if load_in_8bit:
|
|
|
+ dtype = block.input_layernorm.weight.dtype
|
|
|
+ assert dtype == torch.float16, f"'load_in_8bit' does not support {dtype} for now"
|
|
|
+ block = replace_8bit_linear(block)
|
|
|
+
|
|
|
+ block = block.to(device)
|
|
|
for param in block.parameters():
|
|
|
param.requires_grad = False
|
|
|
|