|
@@ -4,7 +4,7 @@ import multiprocessing as mp
|
|
import random
|
|
import random
|
|
import threading
|
|
import threading
|
|
import time
|
|
import time
|
|
-from typing import Dict, Optional, List, Sequence, Union
|
|
|
|
|
|
+from typing import Dict, List, Optional, Sequence, Union
|
|
|
|
|
|
import torch
|
|
import torch
|
|
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
|
|
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
|
|
@@ -60,7 +60,7 @@ class Server(threading.Thread):
|
|
prefetch_batches: int = 1,
|
|
prefetch_batches: int = 1,
|
|
sender_threads: int = 1,
|
|
sender_threads: int = 1,
|
|
mean_block_selection_delay: float = 0.5,
|
|
mean_block_selection_delay: float = 0.5,
|
|
- mean_balance_check_period: float = 300,
|
|
|
|
|
|
+ mean_balance_check_period: float = 60, # TODO:
|
|
use_auth_token: Optional[str] = None,
|
|
use_auth_token: Optional[str] = None,
|
|
load_in_8bit: bool = False,
|
|
load_in_8bit: bool = False,
|
|
*,
|
|
*,
|
|
@@ -178,6 +178,7 @@ class Server(threading.Thread):
|
|
if self.stop.wait(timeout):
|
|
if self.stop.wait(timeout):
|
|
return
|
|
return
|
|
if self._should_choose_other_blocks():
|
|
if self._should_choose_other_blocks():
|
|
|
|
+ logger.info("Network is imbalanced, server will load other blocks")
|
|
break # Stop serving this set of modules
|
|
break # Stop serving this set of modules
|
|
finally:
|
|
finally:
|
|
self.module_container.shutdown()
|
|
self.module_container.shutdown()
|
|
@@ -217,7 +218,7 @@ class ModuleContainer(threading.Thread):
|
|
dht: DHT,
|
|
dht: DHT,
|
|
module_backends: Dict[str, TransformerBackend],
|
|
module_backends: Dict[str, TransformerBackend],
|
|
*,
|
|
*,
|
|
- device: torch.device,
|
|
|
|
|
|
+ inference_max_length: int,
|
|
num_connection_handlers: int,
|
|
num_connection_handlers: int,
|
|
throughput: float,
|
|
throughput: float,
|
|
update_period: float,
|
|
update_period: float,
|
|
@@ -230,9 +231,10 @@ class ModuleContainer(threading.Thread):
|
|
self.dht, self.module_backends = dht, module_backends
|
|
self.dht, self.module_backends = dht, module_backends
|
|
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
|
|
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
|
|
self.conn_handlers = [
|
|
self.conn_handlers = [
|
|
- TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
|
|
|
|
|
|
+ TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
|
|
|
|
+ for _ in range(num_connection_handlers)
|
|
]
|
|
]
|
|
- self.runtime = Runtime(self.module_backends, device=device, **kwargs)
|
|
|
|
|
|
+ self.runtime = Runtime(self.module_backends, **kwargs)
|
|
self.dht_handler_thread = ModuleAnnouncerThread(
|
|
self.dht_handler_thread = ModuleAnnouncerThread(
|
|
self.module_backends,
|
|
self.module_backends,
|
|
dht,
|
|
dht,
|