justheuristic committed 3 years ago
commit 471e47c0f5
4 changed files with 30 additions and 36 deletions
  1. src/bloom/block.py (+2 -9)
  2. src/bloom/model.py (+2 -5)
  3. src/client/remote_model.py (+4 -11)
  4. src/client/remote_sequential.py (+22 -11)

+ 2 - 9
src/bloom/block.py

@@ -9,15 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (
-    BloomGelu,
-    BloomScaledSoftmax,
-    attention_mask_func,
-    build_alibi_tensor,
-    dropout_add,
-    pre_process_alibi_for_pad,
-    split_tensor_along_last_dim,
-)
+from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
+                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
 
 
 class BloomAttention(nn.Module):

+ 2 - 5
src/bloom/model.py

@@ -11,11 +11,8 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-)
+from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
+                                     add_start_docstrings_to_model_forward)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

+ 4 - 11
src/client/remote_model.py

@@ -1,29 +1,22 @@
 # this code is in active development, interfaces may change
+from typing import Optional
 
-from hivemind import DHT, use_hivemind_log_handler, get_logger
+from hivemind import DHT, get_logger, use_hivemind_log_handler
 
-from src.bloom import DistributedBloomConfig, BloomForCausalLM
+from src.bloom import BloomForCausalLM, DistributedBloomConfig
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
 class DistributedBloomForCausalLM(BloomForCausalLM):
     """BloomForCausalLM, but all transformer layers are hosted by the swarm"""
-    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
-        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
-        if prefix.endswith(UID_DELIMITER):
-            logger.warning(f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
-                           f"This will cause {self.__class__.__name__} to look for modules under "
-                           f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended.")
 
+    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: Optional[str] = None):
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
         assert len(self.transformer.h) == 0
         config.n_layer = n_layer
         self.transformer.h = RemoteSequential(config, dht, prefix)
-
-

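For context, a minimal usage sketch of the reworked constructor. This is not part of the commit: the bootstrap multiaddr, model name, and dht_prefix below are placeholders, and the config is assumed to come via the standard from_pretrained API that DistributedBloomConfig inherits from transformers.

# Hypothetical usage sketch (not from this commit).
import hivemind

from src.bloom import DistributedBloomConfig
from src.client.remote_model import DistributedBloomForCausalLM

INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/<SERVER_PEER_ID>"]  # placeholder multiaddr

config = DistributedBloomConfig.from_pretrained("bigscience/bloom-6b3")  # placeholder model name
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)

# Inside __init__, config.n_layer is temporarily set to 0 so that no local transformer
# blocks are allocated; the layers are then served by RemoteSequential over the DHT.
model = DistributedBloomForCausalLM(config, dht, prefix="bloom6b3")
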
+ 22 - 11
src/client/remote_sequential.py

@@ -1,16 +1,15 @@
 import logging
 from functools import partial
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
+from hivemind import DHT, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
 from src import DistributedBloomConfig
 from src.data_structures import UID_DELIMITER, RemoteModuleInfo
-from src.dht_utils import _get_remote_module_infos, _create_remote_modules_from_infos
-from hivemind import DHT, use_hivemind_log_handler, get_logger
-
+from src.dht_utils import _create_remote_modules_from_infos, _get_remote_module_infos
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -18,7 +17,16 @@ logger = get_logger(__file__)
 
 class RemoteSequential(nn.Sequential):
     """A sequence of transformer blocks hosted by the swarm"""
-    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
+
+    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: Optional[str] = None, max_retries: int = 3):
+        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
+        if prefix.endswith(UID_DELIMITER):
+            logger.warning(
+                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
+                f"This will cause {self.__class__.__name__} to look for modules under "
+                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
+            )
+
         super().__init__()
         self.config = config
         self.dht = dht
@@ -27,9 +35,12 @@ class RemoteSequential(nn.Sequential):
         self.prefix = prefix
         self.block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
         logger.debug(f"Remote block uids: {self.block_uids}")
-        self.block_infos: Tuple[RemoteModuleInfo, ...] = tuple(dht.run_coroutine(
-            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float('inf')), return_future=False
-        ))
+        self.block_infos: Tuple[RemoteModuleInfo, ...] = tuple(
+            dht.run_coroutine(
+                partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
+                return_future=False,
+            )
+        )
 
         self.max_retries = max_retries
 
@@ -39,7 +50,7 @@ class RemoteSequential(nn.Sequential):
             assert info is not None, f"Found no active peers for block {uid}"
             assert isinstance(info.peer_ids, set), f"expected peer_ids to be a set, got {info.peer_ids}"
             assert info.uid == uid, f"The DHT entry for {uid} actually points to {info.uid}"
-            assert len(info.peer_ids) > 0,  f"Found no active peers for block {uid}"
+            assert len(info.peer_ids) > 0, f"Found no active peers for block {uid}"
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -47,7 +58,7 @@ class RemoteSequential(nn.Sequential):
             for retry_index in range(self.max_retries):
                 try:
                     block = self[block_index]
-                    outputs, = block(inputs)
+                    (outputs,) = block(inputs)
                     assert isinstance(outputs, torch.Tensor)
                     assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                     inputs = outputs
@@ -61,7 +72,7 @@ class RemoteSequential(nn.Sequential):
 
     def __getitem__(self, block_index: int):
         assert 0 <= block_index < self.config.n_layer
-        module, = _create_remote_modules_from_infos([self.block_infos[block_index]], self.p2p)
+        (module,) = _create_remote_modules_from_infos([self.block_infos[block_index]], self.p2p)
         return module
 
     def __iter__(self):
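
As an aside, a small sketch of the block-UID scheme that the relocated warning is about. This is illustrative only: make_block_uids is a hypothetical helper and the prefix and layer count are placeholders; the only part taken from the diff is the f"{prefix}{UID_DELIMITER}{i}" expression in RemoteSequential.__init__.

# Illustrative sketch (not from this commit): how per-block DHT keys are derived.
from src.data_structures import UID_DELIMITER


def make_block_uids(prefix: str, n_layer: int) -> tuple:
    # Mirrors the expression used in RemoteSequential.__init__
    return tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(n_layer))


uids = make_block_uids("bloom6b3", n_layer=3)  # placeholder prefix and layer count
# If `prefix` already ends with UID_DELIMITER, every generated key contains a
# doubled delimiter, which is the situation the warning moved into __init__ reports.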