Browse Source

black-isort

justheuristic 3 years ago
parent
commit
d179ec138d
4 changed files with 29 additions and 14 deletions
  1. cli/run_server.py (+5 −5)
  2. src/client/inference_session.py (+16 −5)
  3. src/client/remote_generation.py (+1 −1)
  4. src/server/handler.py (+7 −3)

+ 5 - 5
cli/run_server.py

@@ -12,15 +12,15 @@ import re
 
 
 def parse_size_as_bytes(size: str) -> int:
-    """ parse human-readable data size e.g. 1.5GB, based on https://stackoverflow.com/a/42865957/2002471 """
-    units = {"B": 1, "KB": 2 ** 10, "MB": 2 ** 20, "GB": 2 ** 30, "TB": 2 ** 40, "PB": 2 ** 50}
+    """parse human-readable data size e.g. 1.5GB, based on https://stackoverflow.com/a/42865957/2002471"""
+    units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40, "PB": 2**50}
     size = size.strip().upper().rstrip("IB ")
     if not size.endswith("B"):
         size += "B"
-    if not re.match(r' ', size):
-        size = re.sub(r'([KMGT]?)', r' \1', size)
+    if not re.match(r" ", size):
+        size = re.sub(r"([KMGT]?)", r" \1", size)
     number, unit = [string.strip() for string in size.split()]
-    return int(float(number)*units[unit])
+    return int(float(number) * units[unit])
 
 
 def main():
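
Note: as rendered above, the substitution pattern ([KMGT]?) can match the empty string at every position, which would turn "1.5GB" into " 1 . 5 G B " and break the two-way split that follows. The StackOverflow answer cited in the docstring uses ([KMGT]?B), so the missing B here may be a transcription artifact or an upstream bug. A self-contained sketch with the cited pattern, plus a usage check:

    import re

    def parse_size_as_bytes(size: str) -> int:
        """Parse a human-readable size; the pattern ([KMGT]?B) follows the cited SO answer."""
        units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40, "PB": 2**50}
        size = size.strip().upper().rstrip("IB ")  # "1.5GiB" -> "1.5G"
        if not size.endswith("B"):
            size += "B"  # "1.5G" -> "1.5GB"
        if not re.match(r" ", size):
            size = re.sub(r"([KMGT]?B)", r" \1", size)  # "1.5GB" -> "1.5 GB"
        number, unit = [string.strip() for string in size.split()]
        return int(float(number) * units[unit])

    assert parse_size_as_bytes("1.5GB") == int(1.5 * 2**30)  # 1610612736
    assert parse_size_as_bytes("100") == 100  # bare numbers are treated as bytes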

+ 16 - 5
src/client/inference_session.py

@@ -7,12 +7,13 @@ from typing import AsyncIterator, List, Optional
 import torch
 from hivemind import (
     P2P,
+    MSGPackSerializer,
     anext,
     deserialize_torch_tensor,
     get_logger,
     nested_flatten,
     serialize_torch_tensor,
-    use_hivemind_log_handler, MSGPackSerializer,
+    use_hivemind_log_handler,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
@@ -33,8 +34,15 @@ class RemoteTransformerBlockInferenceSession:
     :note: this inference session is *not* fault-tolerant out of the box
     """
 
-    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator,
-                 *, max_length: int):
+    def __init__(
+        self,
+        uid: ModuleUID,
+        rpc_info: RPCInfo,
+        inputs_queue: asyncio.Queue,
+        outputs_aiter: AsyncIterator,
+        *,
+        max_length: int,
+    ):
         self.uid, self.rpc_info = uid, rpc_info
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
@@ -75,7 +83,7 @@ class RemoteTransformerBlockInferenceSession:
                         serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                     ],
-                    metadata=self._serialized_metadata if not self.stepped else None
+                    metadata=self._serialized_metadata if not self.stepped else None,
                 )
             )
         )
@@ -145,7 +153,10 @@ class RemoteSequentialInferenceSession:
             span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
             inference_session = RemoteExpertWorker.run_coroutine(
                 RemoteTransformerBlockInferenceSession._create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout,
+                    stub,
+                    span_uids,
+                    rpc_info=self.sequence_manager.rpc_info,
+                    timeout=self.timeout,
                 )
             )
             self.inference_sessions.append(inference_session)
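
Beyond moving MSGPackSerializer into alphabetical order in the hivemind import list, the hunks above show that the serialized session metadata is attached only on the first step (while self.stepped is still False). A hedged sketch of that round trip, assuming the metadata is a plain dict such as {"max_length": ...} (the exact contents are not shown in this diff):

    from hivemind import MSGPackSerializer

    # Pack once when the session is created; send with the first request only.
    serialized_metadata = MSGPackSerializer.dumps(dict(max_length=2048))

    stepped = False
    metadata = serialized_metadata if not stepped else None  # first step: attach
    assert MSGPackSerializer.loads(metadata) == {"max_length": 2048}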

+ 1 - 1
src/client/remote_generation.py

@@ -62,7 +62,7 @@ class RemoteGenerationMixin:
         ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
         if inputs is not None:
             assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, "inputs must be a 3d tensor [batch, len, hid]"
-        prefix_length = (0 if inputs is None else inputs.size(1))
+        prefix_length = 0 if inputs is None else inputs.size(1)
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
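
For context, inputs here are prefix hidden states shaped [batch, len, hid], so prefix_length is the sequence dimension. A minimal illustration (the hidden size below is made up, not taken from this diff):

    import torch

    inputs = torch.randn(2, 7, 1024)  # [batch=2, len=7, hid=1024]
    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
    prefix_length = 0 if inputs is None else inputs.size(1)
    assert prefix_length == 7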

+ 7 - 3
src/server/handler.py

@@ -68,8 +68,10 @@ class TransformerConnectionHandler(ConnectionHandler):
                     length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
 
                     if prefix_length + length_increment > max_length:
-                        raise ValueError(f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
-                                         f" exceeds pre-allocated maximum {max_length}")
+                        raise ValueError(
+                            f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                            f" exceeds pre-allocated maximum {max_length}"
+                        )
 
                     # Cast inputs to backend dtype
                     hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
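
The reworked error message reflects the cache contract: the attention cache is pre-allocated for max_length tokens per sequence, and each step may append several tokens. A toy illustration of the check (all numbers are made up):

    prefix_length, max_length = 0, 8
    try:
        for length_increment in (5, 2, 2):  # the third step overflows the cache
            if prefix_length + length_increment > max_length:
                raise ValueError(
                    f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                    f" exceeds pre-allocated maximum {max_length}"
                )
            prefix_length += length_increment
    except ValueError as e:
        print(e)  # Maximum length exceeded: prefix 7 + current 2 exceeds pre-allocated maximum 8
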
@@ -207,7 +209,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int) -> Sequence[int]:
+    async def _allocate_caches(
+        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+    ) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""
         async with contextlib.AsyncExitStack() as stack:
             handles = []
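
For reference, a self-contained sketch of the AsyncExitStack pattern that _allocate_caches uses: enter one async allocation per backend, hand all handles to the caller, and release everything in reverse order on exit. The toy allocate_cache below stands in for the real per-backend cache allocation, which this diff does not show:

    import asyncio
    import contextlib

    @contextlib.asynccontextmanager
    async def allocate_cache(name: str):
        handle = f"handle-{name}"  # pretend this reserved device memory
        try:
            yield handle
        finally:
            print(f"released {handle}")

    @contextlib.asynccontextmanager
    async def allocate_caches(names):
        """Allocate one cache per name, yield all handles, free them on exit."""
        async with contextlib.AsyncExitStack() as stack:
            handles = [await stack.enter_async_context(allocate_cache(n)) for n in names]
            yield handles

    async def main():
        async with allocate_caches(["block0", "block1"]) as handles:
            print(handles)  # ['handle-block0', 'handle-block1']
        # on exit: released handle-block1, then released handle-block0

    asyncio.run(main())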