justheuristic 3 years ago
commit 3b9351de1c

+ 3 - 4
cli/inference_one_block.py

@@ -1,13 +1,12 @@
 import argparse
 
 import torch
-from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from tqdm.auto import trange
 
-from src.bloom.model import DistributedBloomConfig
 from src.bloom.block import BloomBlock
+from src.bloom.model import DistributedBloomConfig
 from src.bloom.ops import build_alibi_tensor
-from tqdm.auto import trange
-
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)

+ 0 - 1
cli/run_server.py

@@ -1,5 +1,4 @@
 import configargparse
-
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler

+ 1 - 1
src/bloom/__init__.py

@@ -1 +1 @@
-from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig
+from src.bloom.model import BloomForCausalLM, BloomModel, DistributedBloomConfig

+ 2 - 9
src/bloom/block.py

@@ -9,15 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (
-    BloomGelu,
-    BloomScaledSoftmax,
-    attention_mask_func,
-    dropout_add,
-    pre_process_alibi_for_pad,
-    split_tensor_along_last_dim,
-    build_alibi_tensor,
-)
+from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
+                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
 
 
 class BloomAttention(nn.Module):

+ 2 - 5
src/bloom/model.py

@@ -11,11 +11,8 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-)
+from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
+                                     add_start_docstrings_to_model_forward)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig

+ 1 - 1
src/client/inference_chain.py

@@ -1,5 +1,5 @@
-from typing import Sequence
 from collections import defaultdict
+from typing import Sequence
 
 import torch
 from hivemind import DHT

+ 7 - 5
src/client/remote_block.py

@@ -1,16 +1,18 @@
 from __future__ import annotations
+
 import asyncio
 from functools import partial
-from typing import List, Optional, Union, Sequence, AsyncIterator, Dict, Any
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Union
 
 import torch
+from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertUID, ExpertInfo as RemoteModuleInfo
+from hivemind.moe.expert_uid import ExpertInfo as RemoteModuleInfo
+from hivemind.moe.expert_uid import ExpertUID
 from hivemind.p2p import P2P, PeerID, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.utils import MPFuture, DHTExpiration, get_dht_time, as_aiter, anext, nested_flatten
-from hivemind.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import DHTExpiration, MPFuture, anext, as_aiter, get_dht_time, nested_flatten
 
 from src.server.handler import TransformerConnectionHandler
 

+ 1 - 1
src/server/backend.py

@@ -1,5 +1,5 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Tuple, Sequence
+from typing import Sequence, Tuple
 
 import torch
 from hivemind.moe.server.module_backend import ModuleBackend

+ 2 - 2
src/server/handler.py

@@ -1,12 +1,12 @@
 from typing import AsyncIterator, Dict
 
 import torch
-from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor, nested_flatten
+from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
 
-from src.server.backend import TransformerBackend, MAX_LENGTH
+from src.server.backend import MAX_LENGTH, TransformerBackend
 
 
 class TransformerConnectionHandler(ConnectionHandler):

+ 6 - 5
src/server/server.py

@@ -1,6 +1,8 @@
 from __future__ import annotations
+
+import multiprocessing as mp
 import threading
-from typing import Optional, Dict, Union, Sequence
+from typing import Dict, Optional, Sequence, Union
 
 import torch
 from hivemind import DHT, BatchTensorDescriptor
@@ -8,13 +10,12 @@ from hivemind.moe.server.dht_handler import DHTHandlerThread
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.logging import use_hivemind_log_handler, get_logger
-import multiprocessing as mp
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import DistributedBloomConfig, BloomForCausalLM
+from src import BloomForCausalLM, DistributedBloomConfig
 from src.bloom.block import BloomBlock
-from src.server.cache import MemoryCache
 from src.server.backend import TransformerBackend
+from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
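
Every hunk above reorders imports without changing behavior, which is what an import sorter such as isort produces; the re-wrapped `from ... import (...)` blocks in src/bloom/block.py and src/bloom/model.py suggest a line length of roughly 120 characters. The commit itself does not record the tool or its settings, so the following is only a sketch of how the first file's imports could be reordered the same way via isort's Python API; the `line_length` and `known_first_party` values are assumptions inferred from the diff.

    # Sketch: reproduce the import reordering seen in cli/inference_one_block.py.
    # Assumes isort >= 5 is installed; line_length and known_first_party are
    # guesses based on the diff, not settings recorded in the commit.
    import isort

    before = (
        "from hivemind.utils.logging import use_hivemind_log_handler, get_logger\n"
        "from src.bloom.model import DistributedBloomConfig\n"
        "from src.bloom.block import BloomBlock\n"
        "from src.bloom.ops import build_alibi_tensor\n"
        "from tqdm.auto import trange\n"
    )

    after = isort.code(before, known_first_party=["src"], line_length=120)
    print(after)
    # Expected: the hivemind and tqdm imports come first (alphabetized, with the
    # names inside each line sorted), followed by a separate block of first-party
    # src.* imports, matching the +/- lines in the first hunk.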