justheuristic committed 3 years ago
Commit e32208c954

+ 2 - 9
src/bloom/block.py

@@ -9,15 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (
-    BloomGelu,
-    BloomScaledSoftmax,
-    attention_mask_func,
-    build_alibi_tensor,
-    dropout_add,
-    pre_process_alibi_for_pad,
-    split_tensor_along_last_dim,
-)
+from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
+                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
 
 
 class BloomAttention(nn.Module):

+ 2 - 5
src/bloom/model.py

@@ -9,11 +9,8 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-)
+from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
+                                     add_start_docstrings_to_model_forward)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
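Note: the regrouped transformers.file_utils imports are documentation decorators. As a rough illustration (a hedged sketch with a hypothetical class, not this repository's code), add_start_docstrings prepends shared text to the wrapped object's docstring:

    from transformers.file_utils import add_start_docstrings

    # Hypothetical example: the decorator prepends the given text to the
    # class docstring, so related model classes can share boilerplate docs.
    @add_start_docstrings("Shared introduction for Bloom-style model classes. ")
    class ToyModel:
        """Details specific to this class."""

    print(ToyModel.__doc__)  # shared introduction, then the class-specific text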

+ 1 - 2
src/client/remote_block.py

@@ -11,13 +11,12 @@ from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, nested_flatten, use_hivemind_log_handler, get_logger
+from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
 
 from src.data_structures import RemoteModuleInfo
 from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler
 
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 

+ 4 - 6
src/client/remote_model.py

@@ -1,20 +1,18 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Union, Tuple
+from typing import Optional, Tuple, Union
 
 import hivemind
+import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 
 from src.bloom import BloomForYou, DistributedBloomConfig
 from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
-import torch
-from hivemind import use_hivemind_log_handler
-from torch.nn import CrossEntropyLoss
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 

+ 2 - 2
src/client/remote_sequence_info.py

@@ -3,10 +3,10 @@ from __future__ import annotations
 import dataclasses
 import threading
 from functools import partial
-from typing import Tuple, List, Optional, Sequence, NamedTuple
+from typing import List, NamedTuple, Optional, Sequence, Tuple
 
 from hivemind import DHT, PeerID
-from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src.data_structures import ModuleUID, RemoteModuleInfo
 from src.dht_utils import _get_remote_module_infos

+ 0 - 1
src/client/remote_sequential.py

@@ -15,7 +15,6 @@ from src.client.remote_sequence_info import RemoteSequenceInfo
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
 
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 

+ 2 - 2
src/server/handler.py

@@ -5,11 +5,11 @@ from typing import AsyncIterator, Dict, Sequence
 import torch
 from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
 from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
+from hivemind.utils import as_aiter
 from hivemind.utils.asyncio import anext
 from hivemind.utils.streaming import split_for_streaming
-from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
-from hivemind.utils import as_aiter
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
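Note: the newly grouped imports (DEFAULT_MAX_MSG_SIZE, as_aiter, split_for_streaming) come from hivemind's tensor-streaming utilities. A minimal sketch of how they typically fit together (an illustration under assumptions, not the handler's actual logic):

    import torch
    from hivemind import serialize_torch_tensor
    from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
    from hivemind.utils import as_aiter
    from hivemind.utils.streaming import split_for_streaming

    # Serialize a tensor, then split the resulting protobuf into chunks that
    # fit under the p2p daemon's message-size limit before streaming.
    hidden_states = torch.randn(1, 128, 4096)
    serialized = serialize_torch_tensor(hidden_states)
    chunks = list(split_for_streaming(serialized, DEFAULT_MAX_MSG_SIZE // 2))

    # as_aiter wraps the prepared chunks into an async iterator, the form
    # expected by hivemind's streaming RPC helpers on the receiving side.
    outgoing_stream = as_aiter(*chunks)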

+ 1 - 1
src/server/server.py

@@ -14,7 +14,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src import declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
-from src.data_structures import UID_DELIMITER, CHAIN_DELIMITER
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler

+ 1 - 1
tests/test_full_model.py

@@ -3,7 +3,7 @@ import os
 
 import torch
 import transformers
-from hivemind import use_hivemind_log_handler, get_logger
+from hivemind import get_logger, use_hivemind_log_handler
 
 from src.client.remote_model import DistributedBloomForCausalLM