Aleksandr Borzunov 2 ani în urmă
părinte
comite
b5d54f42c0
1 fișier modificat, cu 7 adăugiri și 5 ștergeri
  1. 7 5
      src/server/server.py

+ 7 - 5
src/server/server.py

@@ -4,7 +4,7 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, Optional, List, Sequence, Union
+from typing import Dict, List, Optional, Sequence, Union
 
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
@@ -60,7 +60,7 @@ class Server(threading.Thread):
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         mean_block_selection_delay: float = 0.5,
-        mean_balance_check_period: float = 300,
+        mean_balance_check_period: float = 60,  # TODO:
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
         *,
@@ -178,6 +178,7 @@ class Server(threading.Thread):
                     if self.stop.wait(timeout):
                         return
                     if self._should_choose_other_blocks():
+                        logger.info("Network is imbalanced, server will load other blocks")
                         break  # Stop serving this set of modules
             finally:
                 self.module_container.shutdown()
@@ -217,7 +218,7 @@ class ModuleContainer(threading.Thread):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
-        device: torch.device,
+        inference_max_length: int,
         num_connection_handlers: int,
         throughput: float,
         update_period: float,
@@ -230,9 +231,10 @@ class ModuleContainer(threading.Thread):
         self.dht, self.module_backends = dht, module_backends
         self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
-            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
+            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
+            for _ in range(num_connection_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
+        self.runtime = Runtime(self.module_backends, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
             self.module_backends,
             dht,