|
@@ -34,6 +34,7 @@ class Server(threading.Thread):
|
|
|
*,
|
|
|
device: torch.device,
|
|
|
num_connection_handlers: int = 8,
|
|
|
+ throughput: float,
|
|
|
update_period: float = 30,
|
|
|
expiration: Optional[float] = None,
|
|
|
start: bool,
|
|
@@ -46,7 +47,12 @@ class Server(threading.Thread):
|
|
|
]
|
|
|
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
|
|
|
self.dht_handler_thread = ModuleAnnouncerThread(
|
|
|
- self.module_backends, dht, update_period, expiration, daemon=True
|
|
|
+ self.module_backends,
|
|
|
+ dht,
|
|
|
+ throughput=throughput,
|
|
|
+ update_period=update_period,
|
|
|
+ expiration=expiration,
|
|
|
+ daemon=True,
|
|
|
)
|
|
|
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
|
|
|
|
|
@@ -88,6 +94,7 @@ class Server(threading.Thread):
|
|
|
cls,
|
|
|
prefix: Optional[str],
|
|
|
converted_model_name_or_path: str,
|
|
|
+ throughput: float,
|
|
|
num_blocks: Optional[int] = None,
|
|
|
block_indices: Optional[str] = None,
|
|
|
num_handlers: Optional[int] = None,
|
|
@@ -177,6 +184,7 @@ class Server(threading.Thread):
|
|
|
return cls(
|
|
|
dht,
|
|
|
blocks,
|
|
|
+ throughput=throughput,
|
|
|
num_connection_handlers=num_handlers,
|
|
|
device=device,
|
|
|
stats_report_interval=stats_report_interval,
|
|
@@ -241,18 +249,31 @@ class ModuleAnnouncerThread(threading.Thread):
|
|
|
"""Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
|
|
|
|
|
|
def __init__(
|
|
|
- self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
|
|
|
+ self,
|
|
|
+ module_backends: Dict[str, TransformerBackend],
|
|
|
+ dht: DHT,
|
|
|
+ throughput: float,
|
|
|
+ update_period: float = 30,
|
|
|
+ expiration: Optional[float] = None,
|
|
|
+ **kwargs
|
|
|
):
|
|
|
super().__init__(**kwargs)
|
|
|
if expiration is None:
|
|
|
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
|
|
|
self.module_backends = module_backends
|
|
|
self.dht = dht
|
|
|
+ self.throughput = throughput
|
|
|
self.update_period = update_period
|
|
|
self.expiration = expiration
|
|
|
self.stop = threading.Event()
|
|
|
|
|
|
def run(self) -> None:
|
|
|
- declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
|
|
|
- while not self.stop.wait(self.update_period):
|
|
|
- declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
|
|
|
+ while True:
|
|
|
+ declare_active_modules(
|
|
|
+ self.dht,
|
|
|
+ self.module_backends.keys(),
|
|
|
+ expiration_time=get_dht_time() + self.expiration,
|
|
|
+ throughput=self.throughput,
|
|
|
+ )
|
|
|
+ if self.stop.wait(self.update_period):
|
|
|
+ break
|