
black-isort

justheuristic, 3 years ago
commit 4ad845bce3
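
This commit is a pure black + isort formatting pass over the repository. A minimal sketch of the kind of rewrite it applies, using the public Python APIs of black and isort (the exact line length and isort profile configured for this commit are assumptions):

    import black
    import isort

    src = (
        "from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,\n"
        "                           pre_process_alibi_for_pad, split_tensor_along_last_dim)\n"
    )
    # isort with the "black" profile rewrites an over-long import one name per line with a trailing comma;
    # black then normalizes quotes, spacing, and line wrapping (line length 120 is a guess).
    formatted = black.format_str(isort.code(src, profile="black", line_length=120), mode=black.Mode(line_length=120))
    print(formatted)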

cli/convert_model.py (+1, -1)

@@ -48,7 +48,7 @@ if __name__ == "__main__":
     config = transformers.AutoConfig.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision
     )
-    model = transformers.AutoModel.from_pretrained(    
+    model = transformers.AutoModel.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
     tokenizer = transformers.AutoTokenizer.from_pretrained(
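
DTYPE_MAP is referenced in the hunk above but not defined in it; a plausible reconstruction of the mapping (the exact keys are an assumption, not shown in this diff):

    import torch

    DTYPE_MAP = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
        "auto": "auto",  # let transformers infer the dtype from the checkpoint
    }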

src/bloom/block.py (+9, -2)

@@ -9,8 +9,15 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
-                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
+from src.bloom.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    build_alibi_tensor,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)
 
 
 class BloomAttention(nn.Module):

src/bloom/model.py (+14, -11)

@@ -9,8 +9,11 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
-                                     add_start_docstrings_to_model_forward)
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
@@ -153,9 +156,9 @@ class BloomModel(BloomPreTrainedModel):
         self.n_head = config.n_head
 
         # Embedding + LN Embedding
-        
+
         # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) # dtype=config.torch_dtype
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)  # dtype=config.torch_dtype
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Transformer blocks
@@ -177,10 +180,10 @@ class BloomModel(BloomPreTrainedModel):
 
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings
-    
+
     def set_requires_grad(self, value):
         for p in self.parameters():
-            p.requires_grad=value
+            p.requires_grad = value
 
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -320,9 +323,9 @@ class BloomForYou(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
-         super().__init__(config)
-         self.transformer = BloomModel(config)
-         self.lm_head = None 
+        super().__init__(config)
+        self.transformer = BloomModel(config)
+        self.lm_head = None
 
-         # Initialize weights and apply final processing
-         self.post_init()
+        # Initialize weights and apply final processing
+        self.post_init()

src/client/remote_model.py (+13, -13)

@@ -31,29 +31,31 @@ class DistributedBloomForYou(BloomForYou):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
-        if 'initial_peers' not in kwargs:
+        if "initial_peers" not in kwargs:
             raise ValueError("Please specify initial_peers=...")
-        
+
         dht = hivemind.DHT(
-            initial_peers=kwargs.pop('initial_peers'), client_mode=kwargs.pop('client_mode', True),
-            start=True)
+            initial_peers=kwargs.pop("initial_peers"), client_mode=kwargs.pop("client_mode", True), start=True
+        )
 
-        if 'prefix' not in kwargs:
+        if "prefix" not in kwargs:
             logger.debug(f"No DHT prefix specified; using automatic prefix {pretrained_model_name_or_path}")
-            assert UID_DELIMITER not in pretrained_model_name_or_path, \
-                f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
+            assert (
+                UID_DELIMITER not in pretrained_model_name_or_path
+            ), f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
         prefix = kwargs.pop("prefix", pretrained_model_name_or_path)
 
         config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
         model = cls(config, dht, prefix)
-        model.transformer.load_state_dict(_load_state_dict(
-            pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
-        ), strict=True) 
+        model.transformer.load_state_dict(
+            _load_state_dict(pretrained_model_name_or_path, use_auth_token=kwargs.get("use_auth_token")), strict=True
+        )
         return model
 
 
 class DistributedBloomForCausalLM(DistributedBloomForYou):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
@@ -86,9 +88,7 @@ class DistributedBloomForCausalLM(DistributedBloomForYou):
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        transformer_outputs = self.transformer.forward(
-            input_ids=input_ids, return_dict=return_dict, **kwargs
-        )
+        transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
 
         # Switch dtype in case word_embeddings are fp16
         word_embeddings = self.transformer.word_embeddings.weight.t()
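
A usage sketch for the from_pretrained override above, based on its keyword handling and on the call in tests/test_full_model.py; the bootstrap multiaddr is a placeholder:

    from src.client.remote_model import DistributedBloomForCausalLM

    INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/PEER_ID_PLACEHOLDER"]  # hypothetical bootstrap peer
    model = DistributedBloomForCausalLM.from_pretrained(
        "bigscience/test-bloomd-6b3",  # model name used by the tests in this repo
        initial_peers=INITIAL_PEERS,   # required: omitting initial_peers raises ValueError
        prefix="bloom6b3",             # optional: falls back to the model name, which must not contain UID_DELIMITER
    )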

src/client/remote_sequence_info.py (+4, -3)

@@ -15,12 +15,13 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
+Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
 
 
 @dataclasses.dataclass(frozen=False, init=False)  # TODO[borzunov@] eto ne dataclass
 class RemoteSequenceInfo:
     """Keeps and updates the meta-information about which peers host which blocks"""
+
     dht: DHT
     block_uids: List[ModuleUID, ...]
     block_infos: List[Optional[RemoteModuleInfo], ...]
@@ -48,8 +49,8 @@ class RemoteSequenceInfo:
 
     def update_block_infos_(self):
         new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
-            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
-            return_future=False)
+            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
+        )
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:

src/client/remote_sequential.py (+2, -2)

@@ -103,8 +103,8 @@ class RemoteSequentialInferenceSession:
 
             # TODO begin throwaway prototype code
             remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
-            _=remote.info #TODO fix
-            span_uids = self.remote_sequence_info.block_uids[current_block: chosen_span.end]
+            _ = remote.info  # TODO fix
+            span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
             remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
             self.active_sessions.append(remote.inference_session())
             self.stack.enter_context(self.active_sessions[-1])
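
The slice above selects a contiguous run of block uids to assign to one peer's span; a small self-contained sketch of that selection (uids and the peer id are made up, and PeerID is simplified to a string):

    from typing import NamedTuple, Optional

    # Same shape as the Span defined in src/client/remote_sequence_info.py.
    Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", str)])

    block_uids = [f"bloom6b3.{i}" for i in range(6)]
    chosen_span = Span(start=3, end=6, peer_id="PEER_ID_PLACEHOLDER")
    span_uids = block_uids[chosen_span.start : chosen_span.end]
    print(" ".join(span_uids))  # "bloom6b3.3 bloom6b3.4 bloom6b3.5" -- the uid string handed to ExpertInfo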

src/server/backend.py (+6, -2)

@@ -30,13 +30,17 @@ class TransformerBackend(ModuleBackend):
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
             hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
-            assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+            assert (
+                hidden_states.ndim == 3
+            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
-                hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+                hidden_states, (new_k, new_v) = self.module.forward(
+                    hidden_states, layer_past=layer_past, use_cache=True
+                )
 
                 # todo remove these asserts once we pass all tests
                 new_length = new_v.shape[1]
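
The asserts above imply a single cache tensor whose leading dimension separates past keys from past values; a sketch of that layout (the trailing dimensions and sizes below are assumptions, not taken from the repo):

    import torch

    # cache[0] holds past keys, cache[1] holds past values; prefix_length trims to the tokens generated so far.
    cache = torch.zeros(2, 1, 128, 32, 128)  # [2, batch, max_seq, heads, head_dim] -- hypothetical shape
    prefix_length = 16
    layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
    assert cache.shape[0] == 2 and cache.ndim == 5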

src/server/handler.py (+32, -44)

@@ -76,22 +76,22 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Run a chain of requested backends 
+        # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
             assert (
                 len(hidden_states) == 1 and hidden_states[0].ndim == 3
             ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
             hidden_states = await backend.forward_pool.submit_task(*hidden_states)
-        
+
         # Serialize the overall output and respond
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
-        return runtime_pb2.ExpertResponse(tensors=[
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
-            )
-        ])
+        return runtime_pb2.ExpertResponse(
+            tensors=[
+                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            ]
+        )
 
     async def rpc_forward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -101,48 +101,41 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_header_str(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        # Run a chain of requested backends 
+        # Run a chain of requested backends
         for backend in requested_backends:
             assert isinstance(hidden_states, (list, tuple))
             assert (
                 len(hidden_states) == 1 and hidden_states[0].ndim == 3
             ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
             hidden_states = await backend.forward_pool.submit_task(*hidden_states)
-        
+
         # Serialize the overall output
         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
-            )
+            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
         ]
 
         # Split the serialized_output for streaming and respond
         output_split = [
-            part
-            for tensor in serialized_output
-            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])
 
-    async def rpc_backward(
-        self, request: runtime_pb2.ExpertRequest, context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
+    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
         inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_header(request)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
         # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its output 
+        # Note that we do not forward for the last module since we do not need its output
         inter_inputs = [inputs]
         for backend in requested_backends[:-1]:
-            assert (inputs.ndim == 3
-            ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
             inputs = await backend.forward_pool.submit_task(inputs)
-            assert (isinstance(inputs, (list, tuple)) and len(inputs) == 1)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
             inputs = inputs[0]
             inter_inputs.append(inputs)
 
@@ -150,16 +143,16 @@ class TransformerConnectionHandler(ConnectionHandler):
         for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
             inputs_and_grads = [inp, grads]
             grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert (isinstance(grads, (list, tuple)) and len(grads) == 1)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
             grads = grads[0]
-        
+
         # Serialize the overall grad_input and respond
-        return runtime_pb2.ExpertResponse(tensors=[
-            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
-            )
-        ])
+        return runtime_pb2.ExpertResponse(
+            tensors=[
+                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
+            ]
+        )
 
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -170,35 +163,30 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
         # Run a forward chain to collect intermediate inputs
-        # Note that we do not forward for the last module since we do not need its outputs 
+        # Note that we do not forward for the last module since we do not need its outputs
         inter_inputs = [inputs]
         for backend in requested_backends[:-1]:
-            assert (inputs.ndim == 3
-            ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
             inputs = await backend.forward_pool.submit_task(inputs)
-            assert (isinstance(inputs, (list, tuple)) and len(inputs) == 1)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
             inputs = inputs[0]
             inter_inputs.append(inputs)
 
-         # Run a backward chain for requested backends
+        # Run a backward chain for requested backends
         for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
             inputs_and_grads = [inp, grads]
             grads = await backend.backward_pool.submit_task(*inputs_and_grads)
-            assert (isinstance(grads, (list, tuple)) and len(grads) == 1)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
             grads = grads[0]
-        
+
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-            for result, proto in zip(
-                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
-            )
+            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
-            part
-            for tensor in serialized_grad_inputs
-            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
 
         async for part in as_aiter(*output_split):
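
rpc_backward and rpc_backward_stream above share the same two-phase pattern: a forward chain that records each backend's input (the last module's output is never needed), then a backward chain over those inputs in reverse. A minimal sketch of that control flow, with plain callables standing in for the pooled module backends (names here are illustrative only):

    def chained_backward(backends, inputs, grads):
        # Forward chain: record every backend's input; skip the last module, whose output is unused.
        inter_inputs = [inputs]
        for backend in backends[:-1]:
            inputs = backend.forward(inputs)
            inter_inputs.append(inputs)
        # Backward chain: walk the recorded inputs and the backends in reverse order.
        for inp, backend in zip(inter_inputs[::-1], backends[::-1]):
            grads = backend.backward(inp, grads)
        return grads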

src/server/server.py (+6, -3)

@@ -111,9 +111,10 @@ class Server(threading.Thread):
             add_custom_models_from_file(custom_module_path)
         if prefix is None:
             prefix = converted_model_name_or_path
-            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix,\
-                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); " \
+            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
+                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                 f"Please specify --prefix manually when starting a server"
+            )
             logger.info(f"Automatic dht prefix: {prefix}")
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
@@ -139,7 +140,9 @@ class Server(threading.Thread):
             assert num_blocks is not None
             block_indices = range(num_blocks)  # TODO replace with proper load balancing
 
-        block_config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
+        block_config = DistributedBloomConfig.from_pretrained(
+            converted_model_name_or_path, use_auth_token=use_auth_token
+        )
 
         # initialize modules
         blocks = {}

tests/test_chained_forward_backward.py (+3, -3)

@@ -30,10 +30,10 @@ REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
 # seq_length <= 128: rpc_forward & rpc_backward
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    remote_block, = get_remote_module(dht, BLOCK_UID)
+    (remote_block,) = get_remote_module(dht, BLOCK_UID)
     assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
     assert isinstance(remote_block, RemoteTransformerBlock)
-    
+
     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
     remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id)
 
@@ -41,7 +41,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
         load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
         load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
         load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
-    ]        
+    ]
     inputs = torch.randn(1, seq_length, 4096, requires_grad=True)
     outputs_rpc = remote_block.forward(inputs)[0]
     outputs_rpc.sum().backward()

tests/test_full_model.py (+2, -2)

@@ -29,7 +29,7 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix=prefix)
     assert len(model.transformer.h) == model.config.n_layer
 
-    test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
+    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     parallel_outputs = model.forward(test_inputs).logits
     assert torch.all(torch.isfinite(parallel_outputs))
     logger.info("Forward outputs are finite")
@@ -49,7 +49,7 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="
     recurrent_outputs = []
     with model.transformer.h.inference_session() as sess:
         for t in range(embs.shape[1]):
-            recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
+            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
     recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
     recurrent_outputs = model.transformer.ln_f(recurrent_outputs)