
black everything

justheuristic 3 years ago
parent
commit
e8241d2915
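
The commit message says the whole repository was reformatted with the Black code formatter, which is what the hunks below show: single quotes normalized to double quotes, long import lists wrapped one name per line, and slice/keyword spacing adjusted. A minimal sketch of such a pass via Black's Python API is given here; the line_length of 120 is an assumption inferred from the long lines Black left untouched in this diff, it is not stated anywhere in the commit.

# Hypothetical sketch of the formatting pass behind this commit, using Black's
# public Python API (equivalent to running `black .` at the repository root).
# line_length=120 is an assumption inferred from the diff, not from the commit.
import black

mode = black.Mode(line_length=120)

# Black normalizes string quotes, exactly as in the cli/inference_one_block.py hunk below.
source = "if __name__ == '__main__':\n    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n"
print(black.format_str(source, mode=mode))
# -> if __name__ == "__main__":
# ->     device = "cuda" if torch.cuda.is_available() else "cpu"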

+ 2 - 2
cli/inference_one_block.py

@@ -25,7 +25,7 @@ def print_device_info(device=None):
         logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
     parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
     parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
@@ -35,7 +35,7 @@ if __name__ == '__main__':
     args = parser.parse_args()
 
     if args.device is None:
-        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        args.device = "cuda" if torch.cuda.is_available() else "cpu"
 
     config = DistributedBloomConfig.from_json_file(args.config)
     block = BloomBlock(config, args.layer_index).to(args.device)

+ 9 - 2
src/bloom/block.py

@@ -9,8 +9,15 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
-                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
+from src.bloom.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    build_alibi_tensor,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)
 
 
 class BloomAttention(nn.Module):

+ 5 - 2
src/bloom/model.py

@@ -11,8 +11,11 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
-                                     add_start_docstrings_to_model_forward)
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

+ 2 - 1
src/client/inference_chain.py

@@ -22,7 +22,8 @@ class RemoteInferenceChain(nn.Module):
     def step(self, hidden_states: torch.Tensor):
         pass
 
+
 # plan:
 # - run inference STUB from a jupyter notebook
 # - extend to run actual inference
-# - extend to run multiple layers at a time
+# - extend to run multiple layers at a time

+ 19 - 11
src/client/remote_block.py

@@ -31,6 +31,7 @@ class RemoteTransformerBlock(RemoteExpert):
 
 class RemoteTransformerBlockInferenceSession:
     """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
+
     def __init__(self, uid: ExpertUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
         self.uid, self.info = uid, info
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
@@ -41,7 +42,7 @@ class RemoteTransformerBlockInferenceSession:
 
     @classmethod
     async def _create(
-            cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
+        cls, remote_module: RemoteTransformerBlock, timeout: Optional[float] = None
     ) -> RemoteTransformerBlockInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
@@ -64,12 +65,17 @@ class RemoteTransformerBlockInferenceSession:
             raise Exception("Session is closed, cannot perform step")
         # serialize inputs and put them into the queue
         inputs = (new_hidden_states,)
-        outputs_serialized = RemoteExpertWorker.run_coroutine(self._step(
-            runtime_pb2.ExpertRequest(uid=self.uid, tensors=[
-                serialize_torch_tensor(tensor, proto.compression)
-                for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
-            ])
-        ))
+        outputs_serialized = RemoteExpertWorker.run_coroutine(
+            self._step(
+                runtime_pb2.ExpertRequest(
+                    uid=self.uid,
+                    tensors=[
+                        serialize_torch_tensor(tensor, proto.compression)
+                        for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
+                    ],
+                )
+            )
+        )
         outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
         assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
         return outputs[0]
@@ -119,10 +125,11 @@ def get_remote_module(
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time),
-        return_future)
+        partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time), return_future
+    )
 
     if return_future:
+
         async def _unpack(infos_future: MPFuture, dht: DHT):
             p2p = await dht.replicate_p2p()
             return _create_remote_modules_from_infos(await infos_future, p2p)
@@ -148,8 +155,9 @@ async def _get_remote_module_infos(
     return experts
 
 
-def _create_remote_modules_from_infos(infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
-                                      ) -> List[Optional[RemoteTransformerBlock]]:
+def _create_remote_modules_from_infos(
+    infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
+) -> List[Optional[RemoteTransformerBlock]]:
     experts: List[Optional[RemoteTransformerBlock]] = []
     for info in infos:
         if info is not None:

+ 9 - 6
src/server/backend.py

@@ -28,11 +28,14 @@ class TransformerBackend(ModuleBackend):
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         attention_cache_handle = int(cache_metadata[0, 0].item())
         prefix_length = int(cache_metadata[0, 1].item())
-        hidden_states, *_ = inputs  # todo: this ignores any extra inputs for now; in future, it would be best to support attention mask as an extra input
+        (
+            hidden_states,
+            *_,
+        ) = inputs  # todo: this ignores any extra inputs for now; in future, it would be best to support attention mask as an extra input
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
         with self.memory_cache.use_cache(attention_cache_handle) as cache:
-            print('METADATA:', cache_metadata)
+            print("METADATA:", cache_metadata)
             assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
             layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
             print(past_k.shape, past_v.shape)
@@ -44,10 +47,10 @@ class TransformerBackend(ModuleBackend):
             assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
             assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
             assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
-            assert torch.allclose(new_v[:, :past_v.shape[1]], past_v)
-            assert torch.allclose(new_k[:, :past_k.shape[1]], past_k)
-            cache[0, :, prefix_length: new_length, :] = new_k[:, prefix_length: new_length]
-            cache[1, :, prefix_length: new_length, :] = new_v[:, prefix_length: new_length]
+            assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
+            assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
+            cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
+            cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
             return (hidden_states,)
 
     def get_pools(self) -> Sequence[TaskPool]:

+ 1 - 1
src/server/handler.py

@@ -46,6 +46,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                     yield runtime_pb2.ExpertResponse(tensors=outputs)
 
                     prefix_length += inputs[1].shape[1]
-                    request = await(anext(requests))
+                    request = await (anext(requests))
         finally:
             print("CLOSED RPC_INFERENCE")

+ 2 - 3
src/server/server.py

@@ -114,7 +114,7 @@ class Server(threading.Thread):
 
         if block_indices is not None:
             try:
-                start, end = block_indices.split(':')
+                start, end = block_indices.split(":")
                 start, end = map(int, map(str.strip, (start, end)))
             except Exception as e:
                 logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:33)")
@@ -122,14 +122,13 @@ class Server(threading.Thread):
             block_indices = range(start, end)
         else:
             assert num_blocks is not None
-            block_indices = range(num_blocks) # TODO replace with proper load balancing
+            block_indices = range(num_blocks)  # TODO replace with proper load balancing
 
         ## TODO: the code below will load the entire model in RAM. Please replace with sliced model
         block_config = DistributedBloomConfig.from_pretrained(block_config, use_auth_token=True)
         # model = BloomForCausalLM.from_pretrained(model, use_auth_token=True)
         ## /TODO
 
-
         # initialize modules
         blocks = {}
         for block_index in block_indices: