瀏覽代碼

black-isort

justheuristic 3 年之前
父節點
當前提交
83cd4412a1
共有 8 個文件被更改,包括 24 次插入32 次删除
  1. 1 2
      cli/convert_model.py
  2. 1 1
      src/bloom/__init__.py
  3. 2 9
      src/bloom/block.py
  4. 14 10
      src/bloom/from_pretrained.py
  5. 2 5
      src/bloom/model.py
  6. 1 1
      src/bloom/ops.py
  7. 1 1
      src/server/backend.py
  8. 2 3
      src/server/server.py

+ 1 - 2
cli/convert_model.py

@@ -3,10 +3,10 @@ import os
 
 import psutil
 import torch.backends.quantized
+import torch.nn as nn
 import transformers
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from huggingface_hub import Repository
-import torch.nn as nn
 from tqdm.auto import tqdm
 
 use_hivemind_log_handler("in_root_logger")
@@ -85,4 +85,3 @@ if __name__ == "__main__":
         config.save_pretrained(".")
 
     logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
-

+ 1 - 1
src/bloom/__init__.py

@@ -1 +1 @@
-from src.bloom.model import BloomForCausalLM, BloomModel, DistributedBloomConfig, BloomBlock
+from src.bloom.model import BloomBlock, BloomForCausalLM, BloomModel, DistributedBloomConfig

+ 2 - 9
src/bloom/block.py

@@ -9,15 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (
-    BloomGelu,
-    BloomScaledSoftmax,
-    attention_mask_func,
-    build_alibi_tensor,
-    dropout_add,
-    pre_process_alibi_for_pad,
-    split_tensor_along_last_dim,
-)
+from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
+                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
 
 
 class BloomAttention(nn.Module):

+ 14 - 10
src/bloom/from_pretrained.py

@@ -11,18 +11,18 @@ from __future__ import annotations
 from typing import Optional, OrderedDict, Union
 
 import torch
-from hivemind.utils.logging import use_hivemind_log_handler, get_logger
-from transformers.utils.hub import hf_bucket_url, cached_path
-
-from src.bloom import BloomForCausalLM, DistributedBloomConfig, BloomBlock
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from transformers.modeling_utils import WEIGHTS_NAME
+from transformers.utils.hub import cached_path, hf_bucket_url
+
+from src.bloom import BloomBlock, BloomForCausalLM, DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 CLIENT_BRANCH = "client"
 BLOCK_BRANCH_PREFIX = "block_"
-USER_AGENT = {'file_type': 'model', 'framework': 'pytorch', 'from_auto_class': False}
+USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
 cls = BloomForCausalLM
 FORCE_DOWNLOAD = False
 RESUME_DOWNLOAD = False
@@ -30,8 +30,11 @@ LOCAL_FILES_ONLY = False
 
 
 def load_pretrained_block(
-        converted_model_name_or_path: str, block_index: int,
-        config: Optional[DistributedBloomConfig] = None, torch_dtype: Union[torch.dtype, str] = 'auto') -> BloomBlock:
+    converted_model_name_or_path: str,
+    block_index: int,
+    config: Optional[DistributedBloomConfig] = None,
+    torch_dtype: Union[torch.dtype, str] = "auto",
+) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
         config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path)
@@ -39,7 +42,7 @@ def load_pretrained_block(
     state_dict = _load_state_dict(converted_model_name_or_path, block_index)
     block.load_state_dict(state_dict)
 
-    if torch_dtype == 'auto':
+    if torch_dtype == "auto":
         with torch.no_grad():
             for name, param in block.named_parameters():
                 assert name in state_dict, f"{name} not in state dict"
@@ -54,7 +57,8 @@ def load_pretrained_block(
 
 
 def _load_state_dict(
-        pretrained_model_name_or_path: str, block_index: Optional[int] = None) -> OrderedDict[str, torch.Tensor]:
+    pretrained_model_name_or_path: str, block_index: Optional[int] = None
+) -> OrderedDict[str, torch.Tensor]:
     revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
     archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
 
@@ -69,7 +73,7 @@ def _load_state_dict(
         use_auth_token=True,
         user_agent=USER_AGENT,
     )
-    state_dict = torch.load(resolved_archive_file, map_location='cpu')
+    state_dict = torch.load(resolved_archive_file, map_location="cpu")
     return state_dict
 
 

+ 2 - 5
src/bloom/model.py

@@ -11,11 +11,8 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-)
+from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
+                                     add_start_docstrings_to_model_forward)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

+ 1 - 1
src/bloom/ops.py

@@ -7,8 +7,8 @@ import math
 
 import torch
 import torch.autograd
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 
 
 def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):

+ 1 - 1
src/server/backend.py

@@ -35,7 +35,7 @@ class TransformerBackend(ModuleBackend):
             print("METADATA:", cache_metadata)
             assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
             layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-            print('PAST', past_k.shape, past_v.shape)
+            print("PAST", past_k.shape, past_v.shape)
             hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
 
             # todo remove these asserts once we pass all tests

+ 2 - 3
src/server/server.py

@@ -12,12 +12,11 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.bloom.from_pretrained import load_pretrained_block, DistributedBloomConfig, DTYPE_MAP
+from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
@@ -89,7 +88,7 @@ class Server(threading.Thread):
         num_handlers: Optional[int] = None,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
-        torch_dtype: str = 'auto',
+        torch_dtype: str = "auto",
         cache_size_bytes: Optional[int] = None,
         device: Union[str, torch.device] = None,
         initial_peers: Sequence[str] = (),