|
@@ -36,7 +36,6 @@ class Server(threading.Thread):
|
|
|
dht: DHT,
|
|
|
module_backends: Dict[str, TransformerBackend],
|
|
|
*,
|
|
|
- device: torch.device,
|
|
|
num_connection_handlers: int = 8,
|
|
|
throughput: float,
|
|
|
update_period: float = 30,
|
|
@@ -50,7 +49,7 @@ class Server(threading.Thread):
|
|
|
self.conn_handlers = [
|
|
|
TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
|
|
|
]
|
|
|
- self.runtime = Runtime(self.module_backends, device=device, **kwargs)
|
|
|
+ self.runtime = Runtime(self.module_backends, **kwargs)
|
|
|
self.dht_handler_thread = ModuleAnnouncerThread(
|
|
|
self.module_backends,
|
|
|
dht,
|
|
@@ -102,7 +101,7 @@ class Server(threading.Thread):
|
|
|
throughput: Union[float, str],
|
|
|
num_blocks: Optional[int] = None,
|
|
|
block_indices: Optional[str] = None,
|
|
|
- num_handlers: Optional[int] = None,
|
|
|
+ num_handlers: int = 8,
|
|
|
min_batch_size: int = 1,
|
|
|
max_batch_size: int = 4096,
|
|
|
torch_dtype: str = "auto",
|
|
@@ -197,6 +196,7 @@ class Server(threading.Thread):
|
|
|
if load_in_8bit:
|
|
|
block = replace_8bit_linear(block)
|
|
|
|
|
|
+ block = block.to(device)
|
|
|
for param in block.parameters():
|
|
|
param.requires_grad = False
|
|
|
|