
Fix bug with model duplication in RAM

dbaranchuk 3 years ago
parent commit ffeceebd4b
1 changed file with 3 additions and 3 deletions

+ 3 - 3
src/server/server.py

@@ -36,7 +36,6 @@ class Server(threading.Thread):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
-        device: torch.device,
         num_connection_handlers: int = 8,
         throughput: float,
         update_period: float = 30,
@@ -50,7 +49,7 @@ class Server(threading.Thread):
         self.conn_handlers = [
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, device=device, **kwargs)
+        self.runtime = Runtime(self.module_backends, **kwargs)
         self.dht_handler_thread = ModuleAnnouncerThread(
             self.module_backends,
             dht,
@@ -102,7 +101,7 @@ class Server(threading.Thread):
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
-        num_handlers: Optional[int] = None,
+        num_handlers: int = 8,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
         torch_dtype: str = "auto",
@@ -197,6 +196,7 @@ class Server(threading.Thread):
             if load_in_8bit:
                 block = replace_8bit_linear(block)
 
+            block = block.to(device)
             for param in block.parameters():
                 param.requires_grad = False
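
For context: before this commit, Server accepted a device argument and forwarded it to Runtime, which would move the module backends to the device only after every block had already been loaded into host RAM, briefly keeping a duplicate copy of the model there. The diff instead moves each block to the device right after it is loaded (and after the optional 8-bit conversion), so full-precision blocks never accumulate in RAM. A minimal sketch of that pattern, with a hypothetical make_block standing in for the real block-loading code:

import torch
import torch.nn as nn

def make_block(hidden: int = 1024) -> nn.Module:
    # Stand-in for loading one pretrained transformer block on CPU.
    return nn.Sequential(
        nn.Linear(hidden, 4 * hidden),
        nn.GELU(),
        nn.Linear(4 * hidden, hidden),
    )

def load_blocks(block_indices, device: torch.device) -> dict:
    """Move each block to `device` as soon as it is built, so at most one
    extra CPU copy of a block exists at any time (instead of the whole model)."""
    blocks = {}
    for i in block_indices:
        block = make_block()
        block = block.to(device)  # per-block move, mirroring the added line in the diff
        for param in block.parameters():
            param.requires_grad = False  # inference-only server: freeze weights
        blocks[str(i)] = block
    return blocks

blocks = load_blocks(range(2), torch.device("cuda" if torch.cuda.is_available() else "cpu"))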