black-isort

justheuristic 3 years ago
parent commit 14e316b52a
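
This commit applies isort and black across the repository: imports are regrouped, long signatures and call sites are exploded onto one argument per line, operator and comment spacing is normalized (e.g. 2 ** 64 becomes 2**64), and trailing blank lines are dropped. Below is a minimal sketch of the kind of formatting pass that would produce these diffs; the line length of 120 and the target directories cli and src are assumptions inferred from how the reformatted lines wrap and from the file paths in this diff, not from the project's actual configuration.

    # Hypothetical reproduction of the formatting pass; not part of this commit.
    # Runs isort (import ordering) first, then black (code style), per the commit title.
    import subprocess

    # --line-length 120 is an assumption; the project may configure this elsewhere.
    subprocess.run(["isort", "--profile", "black", "--line-length", "120", "cli", "src"], check=True)
    subprocess.run(["black", "--line-length", "120", "cli", "src"], check=True)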

+ 0 - 1
cli/quantize_cpu_naive.py

@@ -51,4 +51,3 @@ if __name__ == "__main__":
 
     model.transformer.h = torch.nn.ModuleList()
     torch.save(model.state_dict(), os.path.join(args.output_path, f"client.pth"))
-

+ 1 - 0
cli/run_server.py

@@ -1,4 +1,5 @@
 import os, sys
+
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))  # add path to src
 
 import configargparse

+ 1 - 1
src/__init__.py

@@ -1 +1 @@
-from .bloom import *
+from .bloom import *

+ 1 - 1
src/bloom/__init__.py

@@ -1 +1 @@
-from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig
+from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig

+ 2 - 1
src/bloom/block.py

@@ -15,7 +15,8 @@ from src.bloom.ops import (
     attention_mask_func,
     dropout_add,
     pre_process_alibi_for_pad,
-    split_tensor_along_last_dim, build_alibi_tensor,
+    split_tensor_along_last_dim,
+    build_alibi_tensor,
 )
 
 

+ 1 - 1
src/client/__init__.py

@@ -1 +1 @@
-from src.client.remote_block import RemoteTransformerBlock
+from src.client.remote_block import RemoteTransformerBlock

+ 2 - 3
src/client/remote_block.py

@@ -37,9 +37,11 @@ def create_remote_module(
     infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
 ) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
     if return_future:
+
         async def _unpack(infos_future: MPFuture, dht: DHT):
             p2p = await dht.replicate_p2p()
             return _create_remote_experts(await infos_future, p2p)
+
         return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
     p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
     return _create_remote_experts(infos, p2p)
@@ -53,6 +55,3 @@ def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> L
         else:
             experts.append(None)
     return experts
-
-
-

+ 1 - 1
src/server/backend.py

@@ -19,6 +19,7 @@ from src.server.cache import MemoryCache
 
 class BloomBlockBackend(ExpertBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
+
     def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
         super().__init__(*args, **kwargs)  # to bypass super.__init__
         self.memory_cache = memory_cache
@@ -31,4 +32,3 @@ class BloomBlockBackend(ExpertBackend):
     def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
         with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
             raise NotImplementedError("TODO")
-

+ 6 - 4
src/server/cache.py

@@ -32,7 +32,7 @@ class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
-        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2 ** 64 - 1)
+        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.device = device
         self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_uint64, 0, lock=False)
@@ -77,12 +77,14 @@ class MemoryCache:
         try:
             async with hivemind.utils.enter_asynchronously(self.lock_metadata):
                 if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
-                                           f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated.")
+                    raise AllocationFailed(
+                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
+                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
+                    )
 
                 allocated_handle = int(self.handle_counter)
                 self.current_size_bytes += allocated_size_bytes
-                self.handle_counter += 1   # note: this will eventually overflow and it is okay
+                self.handle_counter += 1  # note: this will eventually overflow and it is okay
                 self._pending_messages.value += 1
                 self._pipe_send.send((allocated_handle, descr))
 

+ 31 - 23
src/server/server.py

@@ -23,15 +23,24 @@ logger = get_logger(__file__)
 
 class Server(threading.Thread):
     """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
+
     def __init__(
-            self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
-            device: torch.device, num_connection_handlers: int = 8,
-            update_period: float = 30, expiration: Optional[float] = None,
-            start: bool, **kwargs
+        self,
+        dht: DHT,
+        module_backends: Dict[str, BloomBlockBackend],
+        *,
+        device: torch.device,
+        num_connection_handlers: int = 8,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
+        start: bool,
+        **kwargs,
     ):
         threading.Thread.__init__(self)
         self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
-        self.conn_handlers = [TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
+        self.conn_handlers = [
+            TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
+        ]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
         self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
@@ -71,23 +80,23 @@ class Server(threading.Thread):
     # noinspection PyMethodOverriding
     @classmethod
     def create(
-            cls,
-            num_blocks: int,
-            block_config: str,
-            num_handlers: Optional[int] = None,
-            min_batch_size: int = 1,
-            max_batch_size: int = 4096,
-            cache_size_bytes: Optional[int] = None,
-            device: Union[str, torch.device] = None,
-            initial_peers: Sequence[str] = (),
-            compression=CompressionType.NONE,
-            stats_report_interval: Optional[int] = None,
-            custom_module_path=None,
-            update_period: float = 30,
-            expiration: Optional[float] = None,
-            *,
-            start: bool,
-            **kwargs,
+        cls,
+        num_blocks: int,
+        block_config: str,
+        num_handlers: Optional[int] = None,
+        min_batch_size: int = 1,
+        max_batch_size: int = 4096,
+        cache_size_bytes: Optional[int] = None,
+        device: Union[str, torch.device] = None,
+        initial_peers: Sequence[str] = (),
+        compression=CompressionType.NONE,
+        stats_report_interval: Optional[int] = None,
+        custom_module_path=None,
+        update_period: float = 30,
+        expiration: Optional[float] = None,
+        *,
+        start: bool,
+        **kwargs,
     ) -> Server:
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
         if custom_module_path is not None:
@@ -181,4 +190,3 @@ class Server(threading.Thread):
 
         self.runtime.shutdown()
         logger.info("Server shutdown succesfully")
-
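
For reference, the reformatted Server.create signature above can be called as sketched below; all argument values are purely illustrative and are not taken from this commit (in particular, the block_config value is a placeholder, since the diff only shows that the parameter is a string).

    # Hypothetical usage of the Server.create signature shown above.
    # Import path assumed from the file layout in this diff (src/server/server.py).
    from src.server.server import Server

    server = Server.create(
        num_blocks=2,                      # serve two bloom blocks
        block_config="bigscience/bloom",   # assumption: a model/config identifier string
        device="cpu",
        initial_peers=[],                  # no bootstrap peers in this sketch
        start=True,                        # keyword-only, as in the signature above
    )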