justheuristic 3 years ago
parent
commit
6e3db6bed6

+ 1 - 1
cli/run_server.py

@@ -68,7 +68,7 @@ def main():
     compression = getattr(CompressionType, compression_type)
 
     use_auth_token = args.pop("use_auth_token")
-    args['use_auth_token'] = True if use_auth_token in ('True', 'true', '') else use_auth_token
+    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
 
     server = Server.create(**args, start=True, compression=compression)
 

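For context, the coercion in the hunk above maps the use_auth_token argument string onto either a boolean or a literal token. A minimal sketch of that behavior, assuming a standalone helper (coerce_auth_token is illustrative and not part of the repo):

    from typing import Optional, Union

    def coerce_auth_token(raw: Optional[str]) -> Union[bool, str, None]:
        # "True", "true", or an empty string mean "use the locally saved token";
        # any other non-empty string is forwarded as the token itself.
        if raw in ("True", "true", ""):
            return True
        return raw

    assert coerce_auth_token("true") is True
    assert coerce_auth_token("hf_xxx") == "hf_xxx"
    assert coerce_auth_token(None) is None
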
+ 1 - 1
src/bloom/from_pretrained.py

@@ -34,7 +34,7 @@ def load_pretrained_block(
     block_index: int,
     config: Optional[DistributedBloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
-    use_auth_token: Optional[str]=None
+    use_auth_token: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:

+ 1 - 1
src/server/cache.py

@@ -8,7 +8,7 @@ import contextlib
 import ctypes
 import multiprocessing as mp
 import os
-from typing import Dict, Optional, Union, AsyncContextManager
+from typing import AsyncContextManager, Dict, Optional, Union
 
 import hivemind
 import torch

+ 13 - 15
src/server/handler.py

@@ -14,6 +14,7 @@ from src.server.backend import MAX_LENGTH, TransformerBackend
 
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
+
     module_backends: Dict[ModuleUID, TransformerBackend]
 
     def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
@@ -42,18 +43,23 @@ class TransformerConnectionHandler(ConnectionHandler):
                     # run request tensors through all requested modules, update caches
                     for backend, cache_handle in zip(requested_backends, cache_handles):
                         cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
-                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3, \
-                            f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                        assert (
+                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
 
                         hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                         assert isinstance(hidden_states, (list, tuple))
                         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
 
                     # serialize and send last layer outputs
-                    yield runtime_pb2.ExpertResponse(tensors=[
-                        serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-                        for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
-                    ])
+                    yield runtime_pb2.ExpertResponse(
+                        tensors=[
+                            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                            for result, proto in zip(
+                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+                            )
+                        ]
+                    )
 
                     # prepare for next step
                     prefix_length += hidden_states[0].shape[1]
@@ -63,7 +69,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or '').split(CHAIN_DELIMITER)
+        uids = (request.uid or "").split(CHAIN_DELIMITER)
         if not uids:
             raise RuntimeError("User did not provide any uids")
         for uid in uids:
@@ -86,11 +92,3 @@ class TransformerConnectionHandler(ConnectionHandler):
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
 
             yield handles
-
-
-
-
-
-
-
-

+ 1 - 1
src/server/server.py

@@ -143,7 +143,7 @@ class Server(threading.Thread):
                 block_index,
                 block_config,
                 torch_dtype=torch_dtype,
-                use_auth_token=use_auth_token
+                use_auth_token=use_auth_token,
             )
             for param in block.parameters():
                 param.requires_grad = False

+ 2 - 2
tests/test_chained_inference.py

@@ -34,7 +34,7 @@ def test_remote_block_exact_match(atol_inference=1e-4):
     assert isinstance(remote_block, RemoteTransformerBlock)
 
     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
-    remote_block._info = ExpertInfo('bloom6b3.3 bloom6b3.4', remote_block._info.peer_id)
+    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4", remote_block._info.peer_id)
 
     inputs = torch.randn(1, 8, 4096)
 
@@ -46,7 +46,7 @@ def test_remote_block_exact_match(atol_inference=1e-4):
 
     ref_blocks = [
         load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32)
+        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
     ]
     outputs_ref = []
     caches = [None, None]