
black-isort

Author: justheuristic (3 years ago)
Commit: d03b38b9eb
5 changed files with 22 additions and 10 deletions:
  1. src/__init__.py        (+2, -0)
  2. src/bloom/block.py     (+9, -2)
  3. src/bloom/model.py     (+5, -2)
  4. src/data_structures.py (+3, -3)
  5. src/dht_utils.py       (+3, -3)
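
The diff below is mechanical reformatting from running isort and black over the source tree. As a rough sketch of how to reproduce it programmatically (the exact configuration is not recorded in this commit; line_length=120 and the isort-then-black order are assumptions):

    # hedged sketch: sort imports with isort, then reformat with black
    # line_length=120 is an assumption inferred from the long lines kept below
    import black
    import isort

    with open("src/bloom/block.py") as f:
        source = f.read()

    sorted_source = isort.code(source, line_length=120)
    formatted = black.format_str(sorted_source, mode=black.Mode(line_length=120))

    with open("src/bloom/block.py", "w") as f:
        f.write(formatted)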

+ 2 - 0
src/__init__.py

@@ -1,3 +1,5 @@
 from .bloom import *
 from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
 from .dht_utils import declare_active_modules, get_remote_module
+
+__version__ = "0.1"
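
This hunk introduces a package version attribute. A minimal usage sketch, assuming the package is importable as src:

    import src

    print(src.__version__)  # "0.1"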

+ 9 - 2
src/bloom/block.py

@@ -9,8 +9,15 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
-                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
+from src.bloom.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    build_alibi_tensor,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)
 
 
 class BloomAttention(nn.Module):

+ 5 - 2
src/bloom/model.py

@@ -11,8 +11,11 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
-                                     add_start_docstrings_to_model_forward)
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

+ 3 - 3
src/data_structures.py

@@ -1,8 +1,8 @@
-from typing import NamedTuple, Collection
+from typing import Collection, NamedTuple
 
 from hivemind import PeerID
 
-
 ModuleUID = str
-UID_DELIMITER = '.'
+UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
+CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
 RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])
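
A short sketch of how the two delimiters compose and split uids, using the example values from the comments above:

    # build one module uid and one chain of uids
    uid = UID_DELIMITER.join(["bloom", "transformer", "h", "4", "self_attention"])
    chain = CHAIN_DELIMITER.join(["bloom.layer3", "bloom.layer4"])

    assert uid == "bloom.transformer.h.4.self_attention"
    assert chain.split(CHAIN_DELIMITER) == ["bloom.layer3", "bloom.layer4"]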

+ 3 - 3
src/dht_utils.py

@@ -12,7 +12,7 @@ from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 import src
-from src.data_structures import UID_DELIMITER, ModuleUID, RemoteModuleInfo
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -39,7 +39,7 @@ def declare_active_modules(
     if not isinstance(uids, list):
         uids = list(uids)
     for uid in uids:
-        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid
+        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
     return dht.run_coroutine(
         partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
         return_future=not wait,
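
The strengthened assertion additionally rejects uids that embed the chain delimiter. A hedged sketch of the invariant it enforces (the helper name is hypothetical, not part of this commit):

    from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER

    def is_valid_module_uid(uid: str) -> bool:  # hypothetical helper
        # a single uid must contain "." but must not contain " ",
        # since " " separates multiple uids in a chain
        return UID_DELIMITER in uid and CHAIN_DELIMITER not in uid

    assert is_valid_module_uid("bloom.transformer.h.4.self_attention")
    assert not is_valid_module_uid("bloom.layer3 bloom.layer4")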
@@ -68,7 +68,7 @@ def get_remote_module(
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
     expiration_time: Optional[DHTExpiration] = None,
     return_future: bool = False,
-) -> Union[List[Optional["src.RemoteTransformerBlock"]], MPFuture[List[Optional["src.RemoteTransformerBlock"]]]]:
+) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
     :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
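
A hedged usage sketch for the updated annotation (parameters before uid_or_uids fall outside this hunk, so the leading dht argument is an assumption):

    # hypothetical call; assumes a running hivemind.DHT instance named `dht`
    blocks = get_remote_module(dht, ["bloom.transformer.h.3", "bloom.transformer.h.4"])
    # with return_future=True, the call returns an MPFuture over the same list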