
black + isort

justheuristic 3 years ago
commit ca21935c77

+ 2 - 1
cli/convert_model.py

@@ -10,8 +10,9 @@ from huggingface_hub import Repository
 from tqdm.auto import tqdm
 
 from src import BloomModel
+from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
 from src.client import DistributedBloomConfig
-from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 

+ 242 - 283
cli/speed_test.py

File diff suppressed because it is too large


+ 9 - 2
src/bloom/block.py

@@ -9,8 +9,15 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
-                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
+from src.bloom.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    build_alibi_tensor,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)
 
 
 class BloomAttention(nn.Module):
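
For context on the import block reformatted above: build_alibi_tensor and pre_process_alibi_for_pad deal with BLOOM's ALiBi attention biases. Below is a minimal, self-contained sketch of the general ALiBi construction (per-head geometric slopes multiplied by key position), assuming a power-of-two head count; it is an illustration, not the repository's src.bloom.ops implementation.

import torch

def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric slopes 2^(-8/n), 2^(-16/n), ..., assuming num_heads is a power of two.
    start = 2 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

def build_alibi_bias(num_heads: int, seq_len: int) -> torch.Tensor:
    # One bias value per (head, key position); broadcast over query positions
    # when added to the raw attention scores.
    positions = torch.arange(seq_len, dtype=torch.float32)
    return alibi_slopes(num_heads)[:, None] * positions[None, :]

print(build_alibi_bias(num_heads=8, seq_len=4).shape)  # torch.Size([8, 4])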

+ 1 - 1
src/client/__init__.py

@@ -1,4 +1,4 @@
 from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.sequence_manager import RemoteSequenceManager
 from src.client.remote_sequential import RemoteSequential
+from src.client.sequence_manager import RemoteSequenceManager

+ 4 - 2
src/client/remote_model.py

@@ -22,11 +22,12 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU 
+    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
 
 
 class DistributedBloomModel(BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
+
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
@@ -45,7 +46,7 @@ class DistributedBloomModel(BloomModel):
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
         self.h = RemoteSequential(config, dht, config.dht_prefix)
-    
+
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
 
@@ -56,6 +57,7 @@
 
 class DistributedBloomForCausalLM(BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
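
The chunk_size_for_efficient_fp16_on_cpu field reformatted above controls how many vocabulary rows the LM head processes per step when running half-precision weights on CPU. The following is a rough, hypothetical sketch of that chunking idea; chunked_lm_head and its shapes are invented for illustration and are not the repository's LM-head code.

import torch

def chunked_lm_head(hidden: torch.Tensor, weight: torch.Tensor, chunk_size: int = 10000) -> torch.Tensor:
    # hidden: (batch, hidden_dim) float32 activations on CPU
    # weight: (vocab_size, hidden_dim) float16 tied embedding / LM-head weight
    # Upcasting and multiplying the whole fp16 matrix at once is slow and
    # memory-hungry on CPU; doing it chunk_size vocabulary rows at a time
    # keeps the temporary fp32 copy small.
    outputs = []
    for i in range(0, weight.shape[0], chunk_size):
        w = weight[i : i + chunk_size].float()  # upcast one chunk only
        outputs.append(hidden @ w.t())
    return torch.cat(outputs, dim=-1)

logits = chunked_lm_head(torch.randn(2, 64), torch.randn(50_000, 64).half(), chunk_size=10000)
print(logits.shape)  # torch.Size([2, 50000])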

+ 6 - 6
src/client/sequence_manager.py

@@ -3,10 +3,10 @@ from __future__ import annotations
 import threading
 from typing import List, Optional, Sequence, Tuple
 
-from hivemind import DHT, PeerID
+from hivemind import DHT
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState, RemoteSpanInfo
+from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from src.dht_utils import get_remote_module_infos
 
 use_hivemind_log_handler("in_root_logger")
@@ -64,15 +64,15 @@ class RemoteSequenceManager:
                 if server.state != ServerState.ONLINE:
                     continue
                 if peer_id not in active_spans:
-                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
+                    active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
                 else:  # peer_id in active_spans
                     active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
 
             for peer_id in list(active_spans.keys()):
                 if (
-                    peer_id not in info.servers or
-                    info.servers[peer_id].state != ServerState.ONLINE or
-                    block_index == len(block_infos) - 1
+                    peer_id not in info.servers
+                    or info.servers[peer_id].state != ServerState.ONLINE
+                    or block_index == len(block_infos) - 1
                 ):
                     closed_spans.append(active_spans.pop(peer_id))
         assert not active_spans
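
The second hunk above belongs to the span-merging pass: consecutive blocks hosted by the same online peer are folded into one RemoteSpanInfo (the hunk also replaces a stray Span reference with RemoteSpanInfo). Here is a self-contained toy version of that grouping logic; Span and merge_spans are stand-ins rather than the repository's API.

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Span:  # stand-in for src.data_structures.RemoteSpanInfo
    start: int
    end: int
    peer_id: str

def merge_spans(block_servers: List[List[str]]) -> List[Span]:
    # block_servers[i] lists the peers currently serving block i
    # (the real code also filters on ServerState.ONLINE).
    spans: List[Span] = []
    active: Dict[str, Span] = {}
    last = len(block_servers) - 1
    for block_index, peers in enumerate(block_servers):
        for peer_id in peers:
            if peer_id not in active:
                active[peer_id] = Span(block_index, block_index + 1, peer_id)
            else:
                active[peer_id].end = block_index + 1
        for peer_id in list(active):
            if peer_id not in peers or block_index == last:
                spans.append(active.pop(peer_id))
    return spans

print(merge_spans([["A"], ["A", "B"], ["A", "B"], ["B"]]))
# [Span(start=0, end=3, peer_id='A'), Span(start=1, end=4, peer_id='B')]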

+ 4 - 3
src/data_structures.py

@@ -23,15 +23,16 @@ class ServerInfo:
 
 @dataclass
 class RemoteModuleInfo:
-    """ A remote module that is served by one or more servers """
+    """A remote module that is served by one or more servers"""
+
     uid: ModuleUID
     servers: Dict[PeerID, ServerInfo]
 
 
 @dataclass
 class RemoteSpanInfo:
-    """ A chain of remote blocks served by one specific remote peer """
+    """A chain of remote blocks served by one specific remote peer"""
+
     start: int
    end: int
     peer_id: PeerID
-

+ 7 - 9
src/server/server.py

@@ -13,7 +13,7 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import declare_active_modules, BloomConfig
+from src import BloomConfig, declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
@@ -98,7 +98,7 @@ class Server(threading.Thread):
         cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
-        throughput: Union[float, Literal['auto', 'eval']],
+        throughput: Union[float, Literal["auto", "eval"]],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
@@ -140,17 +140,15 @@ class Server(threading.Thread):
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         memory_cache = MemoryCache(device, cache_size_bytes)
 
-        assert isinstance(throughput, float) or throughput in ['auto', 'eval']
-        if throughput in ['auto', 'eval']:
-            throughput = get_host_throughput(device, force_eval=(throughput == 'eval'))
+        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
+        if throughput in ["auto", "eval"]:
+            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
 
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
-        block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token
-        )
+        block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
 
         if block_indices is not None:
             try:
@@ -288,7 +286,7 @@ class ModuleAnnouncerThread(threading.Thread):
         throughput: float,
         update_period: float = 30,
         expiration: float,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.module_backends = module_backends

Some files were not shown because too many files changed in this diff