Эх сурвалжийг харах

Reorder imports with isort (#326)

* Apply isort everywhere

* Config isort in CI

* Update pyproject.toml

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Michael Diskin 4 жил өмнө
parent
commit
bedfa6eefb
68 өөрчлөгдсөн 179 нэмэгдсэн , 165 устгасан
  1. 21 0
      .github/workflows/check-style.yml
  2. 0 13
      .github/workflows/check_style.yml
  3. 3 2
      CONTRIBUTING.md
  4. 1 2
      benchmarks/benchmark_tensor_compression.py
  5. 1 2
      docs/conf.py
  6. 1 1
      examples/albert/arguments.py
  7. 3 3
      examples/albert/run_trainer.py
  8. 2 2
      examples/albert/run_training_monitor.py
  9. 0 1
      examples/albert/utils.py
  10. 4 4
      hivemind/__init__.py
  11. 4 4
      hivemind/averaging/allreduce.py
  12. 7 7
      hivemind/averaging/averager.py
  13. 3 3
      hivemind/averaging/key_manager.py
  14. 2 1
      hivemind/averaging/load_balancing.py
  15. 3 3
      hivemind/averaging/matchmaking.py
  16. 3 4
      hivemind/averaging/partition.py
  17. 3 3
      hivemind/averaging/training.py
  18. 0 1
      hivemind/dht/crypto.py
  19. 3 3
      hivemind/dht/node.py
  20. 6 6
      hivemind/dht/protocol.py
  21. 2 1
      hivemind/dht/routing.py
  22. 1 1
      hivemind/dht/storage.py
  23. 1 1
      hivemind/dht/traverse.py
  24. 2 2
      hivemind/hivemind_cli/run_server.py
  25. 1 1
      hivemind/moe/__init__.py
  26. 8 8
      hivemind/moe/client/beam_search.py
  27. 3 3
      hivemind/moe/client/expert.py
  28. 6 6
      hivemind/moe/client/moe.py
  29. 3 3
      hivemind/moe/client/switch_moe.py
  30. 11 6
      hivemind/moe/server/__init__.py
  31. 3 3
      hivemind/moe/server/connection_handler.py
  32. 5 5
      hivemind/moe/server/dht_handler.py
  33. 3 3
      hivemind/moe/server/expert_backend.py
  34. 1 1
      hivemind/moe/server/expert_uid.py
  35. 1 1
      hivemind/moe/server/layers/custom_experts.py
  36. 1 1
      hivemind/moe/server/runtime.py
  37. 2 2
      hivemind/moe/server/task_pool.py
  38. 1 1
      hivemind/optim/__init__.py
  39. 1 1
      hivemind/optim/adaptive.py
  40. 2 2
      hivemind/optim/collaborative.py
  41. 3 3
      hivemind/optim/simple.py
  42. 3 3
      hivemind/utils/__init__.py
  43. 2 3
      hivemind/utils/asyncio.py
  44. 2 3
      hivemind/utils/compression.py
  45. 2 2
      hivemind/utils/grpc.py
  46. 3 4
      hivemind/utils/mpfuture.py
  47. 0 1
      hivemind/utils/networking.py
  48. 1 1
      hivemind/utils/serializer.py
  49. 1 1
      hivemind/utils/tensor_descr.py
  50. 2 1
      hivemind/utils/timed_storage.py
  51. 7 0
      pyproject.toml
  52. 1 1
      requirements-dev.txt
  53. 2 3
      tests/conftest.py
  54. 1 2
      tests/test_auth.py
  55. 1 0
      tests/test_averaging.py
  56. 1 0
      tests/test_dht.py
  57. 2 2
      tests/test_dht_crypto.py
  58. 2 2
      tests/test_dht_experts.py
  59. 1 1
      tests/test_dht_node.py
  60. 2 2
      tests/test_dht_storage.py
  61. 1 1
      tests/test_dht_validation.py
  62. 1 1
      tests/test_expert_backend.py
  63. 1 2
      tests/test_moe.py
  64. 2 1
      tests/test_p2p_daemon_bindings.py
  65. 2 2
      tests/test_routing.py
  66. 1 1
      tests/test_training.py
  67. 3 3
      tests/test_util_modules.py
  68. 1 2
      tests/test_utils/p2p_daemon.py

+ 21 - 0
.github/workflows/check-style.yml

@@ -0,0 +1,21 @@
+name: Check style
+
+on: [ push ]
+
+jobs:
+  black:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: psf/black@stable
+        with:
+          options: "--check --diff"
+          version: "21.6b0"
+  isort:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - uses: isort/isort-action@master

+ 0 - 13
.github/workflows/check_style.yml

@@ -1,13 +0,0 @@
-name: Check style
-
-on: [ push ]
-
-jobs:
-  black:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - uses: psf/black@stable
-        with:
-          options: "--check"
-          version: "21.6b0"

+ 3 - 2
CONTRIBUTING.md

@@ -34,10 +34,11 @@ with the following rules:
 
 
 ## Code style
 ## Code style
 
 
-* We use [black](https://github.com/psf/black) for code formatting. Before submitting a PR, make sure to install and
-  run `black .` in the root of the repository.
 * The code must follow [PEP8](https://www.python.org/dev/peps/pep-0008/) unless absolutely necessary. Also, each line
 * The code must follow [PEP8](https://www.python.org/dev/peps/pep-0008/) unless absolutely necessary. Also, each line
   cannot be longer than 119 characters.
   cannot be longer than 119 characters.
+* We use [black](https://github.com/psf/black) for code formatting and [isort](https://github.com/PyCQA/isort) for 
+  import sorting. Before submitting a PR, make sure to install and run `black .` and `isort .` in the root of the
+  repository.
 * We highly encourage the use of [typing](https://docs.python.org/3/library/typing.html) where applicable.
 * We highly encourage the use of [typing](https://docs.python.org/3/library/typing.html) where applicable.
 * Use `get_logger` from `hivemind.utils.logging` to log any information instead of `print`ing directly to standard
 * Use `get_logger` from `hivemind.utils.logging` to log any information instead of `print`ing directly to standard
   output/error streams.
   output/error streams.

+ 1 - 2
benchmarks/benchmark_tensor_compression.py

@@ -4,10 +4,9 @@ import time
 import torch
 import torch
 
 
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
-
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 

+ 1 - 2
docs/conf.py

@@ -17,9 +17,8 @@
 # sys.path.insert(0, os.path.abspath('.'))
 # sys.path.insert(0, os.path.abspath('.'))
 import sys
 import sys
 
 
-from recommonmark.transform import AutoStructify
 from recommonmark.parser import CommonMarkParser
 from recommonmark.parser import CommonMarkParser
-
+from recommonmark.transform import AutoStructify
 
 
 # -- Project information -----------------------------------------------------
 # -- Project information -----------------------------------------------------
 src_path = "../hivemind"
 src_path = "../hivemind"

+ 1 - 1
examples/albert/arguments.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
 from dataclasses import dataclass, field
-from typing import Optional, List
+from typing import List, Optional
 
 
 from transformers import TrainingArguments
 from transformers import TrainingArguments
 
 

+ 3 - 3
examples/albert/run_trainer.py

@@ -11,8 +11,8 @@ import transformers
 from datasets import load_from_disk
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
 from torch.utils.data import DataLoader
 from torch_optimizer import Lamb
 from torch_optimizer import Lamb
-from transformers import set_seed, HfArgumentParser, TrainingArguments, DataCollatorForLanguageModeling
-from transformers.models.albert import AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining
+from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
+from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
 from transformers.trainer_utils import is_main_process
@@ -21,7 +21,7 @@ import hivemind
 from hivemind.utils.compression import CompressionType
 from hivemind.utils.compression import CompressionType
 
 
 import utils
 import utils
-from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments
+from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

+ 2 - 2
examples/albert/run_training_monitor.py

@@ -10,13 +10,13 @@ import requests
 import torch
 import torch
 import wandb
 import wandb
 from torch_optimizer import Lamb
 from torch_optimizer import Lamb
-from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 
 import hivemind
 import hivemind
 from hivemind.utils.compression import CompressionType
 from hivemind.utils.compression import CompressionType
 
 
 import utils
 import utils
-from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
+from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 

+ 0 - 1
examples/albert/utils.py

@@ -9,7 +9,6 @@ from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
 from hivemind.dht.validation import RecordValidatorBase
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
-
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 

+ 4 - 4
hivemind/__init__.py

@@ -2,19 +2,19 @@ from hivemind.averaging import DecentralizedAverager, TrainingAverager
 from hivemind.dht import DHT
 from hivemind.dht import DHT
 from hivemind.moe import (
 from hivemind.moe import (
     ExpertBackend,
     ExpertBackend,
-    Server,
-    register_expert_class,
     RemoteExpert,
     RemoteExpert,
     RemoteMixtureOfExperts,
     RemoteMixtureOfExperts,
     RemoteSwitchMixtureOfExperts,
     RemoteSwitchMixtureOfExperts,
+    Server,
+    register_expert_class,
 )
 )
 from hivemind.optim import (
 from hivemind.optim import (
     CollaborativeAdaptiveOptimizer,
     CollaborativeAdaptiveOptimizer,
-    DecentralizedOptimizerBase,
     CollaborativeOptimizer,
     CollaborativeOptimizer,
+    DecentralizedAdam,
     DecentralizedOptimizer,
     DecentralizedOptimizer,
+    DecentralizedOptimizerBase,
     DecentralizedSGD,
     DecentralizedSGD,
-    DecentralizedAdam,
 )
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
 from hivemind.utils import *

+ 4 - 4
hivemind/averaging/allreduce.py

@@ -4,12 +4,12 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 
 
 import torch
 import torch
 
 
-from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
+from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
-from hivemind.utils import get_logger
-from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor, asingle
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.proto import averaging_pb2
 from hivemind.proto import averaging_pb2
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 
 # flavour types
 # flavour types
 GroupID = bytes
 GroupID = bytes

+ 7 - 7
hivemind/averaging/averager.py

@@ -11,12 +11,12 @@ import threading
 import weakref
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
 from dataclasses import asdict
-from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch
 
 
-from hivemind.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
+from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -24,12 +24,12 @@ from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.dht import DHT, DHTID
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2, runtime_pb2
 from hivemind.proto import averaging_pb2, runtime_pb2
-from hivemind.utils import MPFuture, get_logger, TensorDescriptor
-from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.utils.grpc import split_for_streaming, combine_from_streaming
+from hivemind.utils import MPFuture, TensorDescriptor, get_logger
+from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
+from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
 
 
 # flavour types
 # flavour types
 GatheredData = Any
 GatheredData = Any

+ 3 - 3
hivemind/averaging/key_manager.py

@@ -1,14 +1,14 @@
 import asyncio
 import asyncio
-import re
 import random
 import random
-from typing import Optional, List, Tuple
+import re
+from typing import List, Optional, Tuple
 
 
 import numpy as np
 import numpy as np
 
 
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.dht import DHT
 from hivemind.dht import DHT
 from hivemind.p2p import PeerID
 from hivemind.p2p import PeerID
-from hivemind.utils import get_logger, DHTExpiration, get_dht_time, ValueWithExpiration
+from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
 
 
 GroupKey = str
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101

+ 2 - 1
hivemind/averaging/load_balancing.py

@@ -1,4 +1,5 @@
-from typing import Sequence, Optional, Tuple
+from typing import Optional, Sequence, Tuple
+
 import numpy as np
 import numpy as np
 import scipy.optimize
 import scipy.optimize
 
 

+ 3 - 3
hivemind/averaging/matchmaking.py

@@ -10,12 +10,12 @@ from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
 
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
+from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
-from hivemind.utils.asyncio import anext
 from hivemind.proto import averaging_pb2
 from hivemind.proto import averaging_pb2
+from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
+from hivemind.utils.asyncio import anext
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 3 - 4
hivemind/averaging/partition.py

@@ -2,16 +2,15 @@
 Auxiliary data structures for AllReduceRunner
 Auxiliary data structures for AllReduceRunner
 """
 """
 import asyncio
 import asyncio
-from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
 from collections import deque
 from collections import deque
+from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar, Union
 
 
-import torch
 import numpy as np
 import numpy as np
+import torch
 
 
 from hivemind.proto.runtime_pb2 import CompressionType, Tensor
 from hivemind.proto.runtime_pb2 import CompressionType, Tensor
-from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
 from hivemind.utils.asyncio import amap_in_executor
 from hivemind.utils.asyncio import amap_in_executor
-
+from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor
 
 
 T = TypeVar("T")
 T = TypeVar("T")
 DEFAULT_PART_SIZE_BYTES = 2 ** 19
 DEFAULT_PART_SIZE_BYTES = 2 ** 19

+ 3 - 3
hivemind/averaging/training.py

@@ -2,13 +2,13 @@
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from contextlib import nullcontext
 from itertools import chain
 from itertools import chain
-from threading import Lock, Event
-from typing import Sequence, Dict, Iterator, Optional
+from threading import Event, Lock
+from typing import Dict, Iterator, Optional, Sequence
 
 
 import torch
 import torch
 
 
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging import DecentralizedAverager
-from hivemind.utils import nested_flatten, nested_pack, get_logger
+from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 0 - 1
hivemind/dht/crypto.py

@@ -6,7 +6,6 @@ from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 from hivemind.utils import MSGPackSerializer, get_logger
 from hivemind.utils import MSGPackSerializer, get_logger
 from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
 from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
 
 
-
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 

+ 3 - 3
hivemind/dht/node.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 import asyncio
 import dataclasses
 import dataclasses
 import random
 import random
-from collections import defaultdict, Counter
+from collections import Counter, defaultdict
 from dataclasses import dataclass, field
 from dataclasses import dataclass, field
 from functools import partial
 from functools import partial
 from typing import (
 from typing import (
@@ -27,11 +27,11 @@ from sortedcontainers import SortedSet
 
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.routing import DHTID, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, BinaryDHTValue, DHTKey, DHTValue, Subkey, get_dht_time
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
 from hivemind.dht.traverse import traverse_dht
 from hivemind.p2p import P2P, PeerID
 from hivemind.p2p import P2P, PeerID
-from hivemind.utils import MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils import MSGPackSerializer, SerializerBase, get_logger
 from hivemind.utils.auth import AuthorizerBase
 from hivemind.utils.auth import AuthorizerBase
 from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithExpiration
 from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithExpiration
 
 

+ 6 - 6
hivemind/dht/protocol.py

@@ -2,20 +2,20 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import asyncio
 import asyncio
-from typing import Optional, List, Tuple, Dict, Sequence, Union, Collection
+from typing import Collection, Dict, List, Optional, Sequence, Tuple, Union
 
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
-from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, BinaryDHTValue, RoutingTable, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
 from hivemind.proto import dht_pb2
 from hivemind.proto import dht_pb2
-from hivemind.utils import get_logger, MSGPackSerializer
-from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
+from hivemind.utils import MSGPackSerializer, get_logger
+from hivemind.utils.auth import AuthorizerBase, AuthRole, AuthRPCWrapper
 from hivemind.utils.timed_storage import (
 from hivemind.utils.timed_storage import (
-    DHTExpiration,
-    get_dht_time,
     MAX_DHT_TIME_DISCREPANCY_SECONDS,
     MAX_DHT_TIME_DISCREPANCY_SECONDS,
+    DHTExpiration,
     ValueWithExpiration,
     ValueWithExpiration,
+    get_dht_time,
 )
 )
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)

+ 2 - 1
hivemind/dht/routing.py

@@ -7,7 +7,8 @@ import os
 import random
 import random
 from collections.abc import Iterable
 from collections.abc import Iterable
 from itertools import chain
 from itertools import chain
-from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+
 from hivemind.p2p import PeerID
 from hivemind.p2p import PeerID
 from hivemind.utils import MSGPackSerializer, get_dht_time
 from hivemind.utils import MSGPackSerializer, get_dht_time
 
 

+ 1 - 1
hivemind/dht/storage.py

@@ -4,7 +4,7 @@ from typing import Optional, Union
 
 
 from hivemind.dht.routing import DHTID, BinaryDHTValue, Subkey
 from hivemind.dht.routing import DHTID, BinaryDHTValue, Subkey
 from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.serializer import MSGPackSerializer
-from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage, DHTExpiration
+from hivemind.utils.timed_storage import DHTExpiration, KeyType, TimedStorage, ValueType
 
 
 
 
 @MSGPackSerializer.ext_serializable(0x50)
 @MSGPackSerializer.ext_serializable(0x50)

+ 1 - 1
hivemind/dht/traverse.py

@@ -2,7 +2,7 @@
 import asyncio
 import asyncio
 import heapq
 import heapq
 from collections import Counter
 from collections import Counter
-from typing import Dict, Awaitable, Callable, Any, Tuple, List, Set, Collection, Optional
+from typing import Any, Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple
 
 
 from hivemind.dht.routing import DHTID
 from hivemind.dht.routing import DHTID
 
 

+ 2 - 2
hivemind/hivemind_cli/run_server.py

@@ -4,11 +4,11 @@ from pathlib import Path
 import configargparse
 import configargparse
 import torch
 import torch
 
 
-from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.moe.server import Server
 from hivemind.moe.server import Server
+from hivemind.moe.server.layers import schedule_name_to_scheduler
+from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
-from hivemind.moe.server.layers import schedule_name_to_scheduler
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 1 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,2 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, register_expert_class, get_experts, declare_experts
+from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class

+ 8 - 8
hivemind/moe/client/beam_search.py

@@ -2,22 +2,22 @@ import asyncio
 import heapq
 import heapq
 from collections import deque
 from collections import deque
 from functools import partial
 from functools import partial
-from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
+from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 
-from hivemind.dht import DHT, DHTNode, DHTExpiration
+from hivemind.dht import DHT, DHTExpiration, DHTNode
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import (
 from hivemind.moe.server.expert_uid import (
-    ExpertUID,
-    ExpertPrefix,
     FLAT_EXPERT,
     FLAT_EXPERT,
-    UidEndpoint,
-    Score,
-    Coordinate,
     PREFIX_PATTERN,
     PREFIX_PATTERN,
     UID_DELIMITER,
     UID_DELIMITER,
+    Coordinate,
+    ExpertPrefix,
+    ExpertUID,
+    Score,
+    UidEndpoint,
     is_valid_prefix,
     is_valid_prefix,
 )
 )
-from hivemind.utils import get_logger, get_dht_time, MPFuture
+from hivemind.utils import MPFuture, get_dht_time, get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 3 - 3
hivemind/moe/client/expert.py

@@ -1,13 +1,13 @@
 import pickle
 import pickle
-from typing import Tuple, Optional, Any, Dict
+from typing import Any, Dict, Optional, Tuple
 
 
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 from torch.autograd.function import once_differentiable
 
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache
 from hivemind.utils.grpc import ChannelCache
 
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert

+ 6 - 6
hivemind/moe/client/moe.py

@@ -1,8 +1,8 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 import time
 import time
-from queue import Queue, Empty
-from typing import Tuple, List, Optional, Dict, Any
+from queue import Empty, Queue
+from typing import Any, Dict, List, Optional, Tuple
 
 
 import grpc
 import grpc
 import torch
 import torch
@@ -11,11 +11,11 @@ from torch.autograd.function import once_differentiable
 
 
 import hivemind
 import hivemind
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import RemoteExpert, DUMMY, _get_expert_stub
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten, nested_map
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import nested_flatten, nested_map, nested_pack
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)

+ 3 - 3
hivemind/moe/client/switch_moe.py

@@ -1,14 +1,14 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Tuple, List
+from typing import List, Tuple
 
 
 import grpc
 import grpc
 import torch
 import torch
 
 
-from hivemind.moe.client.expert import RemoteExpert, DUMMY
+from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)

+ 11 - 6
hivemind/moe/server/__init__.py

@@ -5,24 +5,29 @@ import multiprocessing.synchronize
 import threading
 import threading
 from contextlib import contextmanager
 from contextlib import contextmanager
 from functools import partial
 from functools import partial
-from typing import Dict, List, Optional, Tuple
 from pathlib import Path
 from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 
 import torch
 import torch
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
 import hivemind
 import hivemind
 from hivemind.dht import DHT
 from hivemind.dht import DHT
-from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
-from hivemind.moe.server.checkpoints import CheckpointSaver, load_experts, is_directory
+from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.layers import name_to_block, name_to_input, register_expert_class
-from hivemind.moe.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
+from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.moe.server.layers import (
+    add_custom_models_from_file,
+    name_to_block,
+    name_to_input,
+    register_expert_class,
+    schedule_name_to_scheduler,
+)
 from hivemind.moe.server.runtime import Runtime
 from hivemind.moe.server.runtime import Runtime
-from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger, BatchTensorDescriptor
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils import BatchTensorDescriptor, Endpoint, find_open_port, get_logger, get_port, replace_port
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 3 - 3
hivemind/moe/server/connection_handler.py

@@ -6,11 +6,11 @@ from typing import Dict
 import grpc
 import grpc
 import torch
 import torch
 
 
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.utils import get_logger, Endpoint, nested_flatten
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import Endpoint, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)

+ 5 - 5
hivemind/moe/server/dht_handler.py

@@ -1,16 +1,16 @@
 import threading
 import threading
 from functools import partial
 from functools import partial
-from typing import Sequence, Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Sequence, Tuple
 
 
-from hivemind.dht import DHT, DHTNode, DHTExpiration, DHTValue
+from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import (
 from hivemind.moe.server.expert_uid import (
-    ExpertUID,
-    ExpertPrefix,
     FLAT_EXPERT,
     FLAT_EXPERT,
-    Coordinate,
     UID_DELIMITER,
     UID_DELIMITER,
     UID_PATTERN,
     UID_PATTERN,
+    Coordinate,
+    ExpertPrefix,
+    ExpertUID,
     is_valid_uid,
     is_valid_uid,
     split_uid,
     split_uid,
 )
 )

+ 3 - 3
hivemind/moe/server/expert_backend.py

@@ -1,12 +1,12 @@
-from typing import Dict, Sequence, Any, Tuple, Union, Callable
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
 
 
 import torch
 import torch
 from torch import nn
 from torch import nn
 
 
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.moe.server.task_pool import TaskPool
-from hivemind.utils.tensor_descr import BatchTensorDescriptor, DUMMY_BATCH_SIZE
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
-from hivemind.utils.nested import nested_flatten, nested_pack, nested_compare, nested_map
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
+from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 1 - 1
hivemind/moe/server/expert_uid.py

@@ -1,6 +1,6 @@
 import random
 import random
 import re
 import re
-from typing import NamedTuple, Union, Tuple, Optional, List
+from typing import List, NamedTuple, Optional, Tuple, Union
 
 
 import hivemind
 import hivemind
 from hivemind.dht import DHT
 from hivemind.dht import DHT

+ 1 - 1
hivemind/moe/server/layers/custom_experts.py

@@ -1,5 +1,5 @@
-import os
 import importlib
 import importlib
+import os
 from typing import Callable, Type
 from typing import Callable, Type
 
 
 import torch
 import torch

+ 1 - 1
hivemind/moe/server/runtime.py

@@ -4,7 +4,7 @@ import threading
 from collections import defaultdict
 from collections import defaultdict
 from itertools import chain
 from itertools import chain
 from queue import SimpleQueue
 from queue import SimpleQueue
-from selectors import DefaultSelector, EVENT_READ
+from selectors import EVENT_READ, DefaultSelector
 from statistics import mean
 from statistics import mean
 from time import time
 from time import time
 from typing import Dict, NamedTuple, Optional
 from typing import Dict, NamedTuple, Optional

+ 2 - 2
hivemind/moe/server/task_pool.py

@@ -10,12 +10,12 @@ from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from collections import namedtuple
 from concurrent.futures import Future
 from concurrent.futures import Future
 from queue import Empty
 from queue import Empty
-from typing import List, Tuple, Dict, Any, Generator
+from typing import Any, Dict, Generator, List, Tuple
 
 
 import torch
 import torch
 
 
 from hivemind.utils import get_logger
 from hivemind.utils import get_logger
-from hivemind.utils.mpfuture import MPFuture, InvalidStateError
+from hivemind.utils.mpfuture import InvalidStateError, MPFuture
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
 Task = namedtuple("Task", ("future", "args"))

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.simple import DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
+from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence
 
 
 import torch.optim
 import torch.optim
 
 
-from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind import TrainingAverager
 from hivemind import TrainingAverager
+from hivemind.optim.collaborative import CollaborativeOptimizer
 
 
 
 
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 2 - 2
hivemind/optim/collaborative.py

@@ -2,8 +2,8 @@ from __future__ import annotations
 
 
 import logging
 import logging
 from dataclasses import dataclass
 from dataclasses import dataclass
-from threading import Thread, Lock, Event
-from typing import Dict, Optional, Iterator
+from threading import Event, Lock, Thread
+from typing import Dict, Iterator, Optional
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch

+ 3 - 3
hivemind/optim/simple.py

@@ -1,13 +1,13 @@
 import time
 import time
-from threading import Thread, Lock, Event
+from threading import Event, Lock, Thread
 from typing import Optional, Sequence, Tuple
 from typing import Optional, Sequence, Tuple
 
 
 import torch
 import torch
 
 
-from hivemind.dht import DHT
 from hivemind.averaging import TrainingAverager
 from hivemind.averaging import TrainingAverager
+from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.utils import get_logger, get_dht_time
+from hivemind.utils import get_dht_time, get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 3 - 3
hivemind/utils/__init__.py

@@ -1,11 +1,11 @@
 from hivemind.utils.asyncio import *
 from hivemind.utils.asyncio import *
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import *
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 from hivemind.utils.mpfuture import *
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.networking import *
-from hivemind.utils.serializer import SerializerBase, MSGPackSerializer
-from hivemind.utils.tensor_descr import TensorDescriptor, BatchTensorDescriptor
+from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *
 from hivemind.utils.timed_storage import *

+ 2 - 3
hivemind/utils/asyncio.py

@@ -1,12 +1,11 @@
-from concurrent.futures import ThreadPoolExecutor
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable, Tuple, Optional, Callable
 import asyncio
 import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
 
 
 import uvloop
 import uvloop
 
 
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
-
 T = TypeVar("T")
 T = TypeVar("T")
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 2 - 3
hivemind/utils/compression.py

@@ -1,15 +1,14 @@
 import os
 import os
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple, Sequence, Optional
+from typing import Optional, Sequence, Tuple
 
 
 import numpy as np
 import numpy as np
 import torch
 import torch
-import warnings
 
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
 
 
-
 FP32_EPS = 1e-06
 FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT16 = 2
 NUM_BYTES_FLOAT16 = 2

+ 2 - 2
hivemind/utils/grpc.py

@@ -6,14 +6,14 @@ from __future__ import annotations
 
 
 import os
 import os
 import threading
 import threading
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable
+from typing import Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
 
 
 import grpc
 import grpc
 
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 from hivemind.utils.networking import Endpoint
 from hivemind.utils.networking import Endpoint
-from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
+from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 3 - 4
hivemind/utils/mpfuture.py

@@ -2,21 +2,20 @@ from __future__ import annotations
 
 
 import asyncio
 import asyncio
 import concurrent.futures._base as base
 import concurrent.futures._base as base
-from weakref import ref
-from contextlib import nullcontext
 import multiprocessing as mp
 import multiprocessing as mp
 import multiprocessing.connection
 import multiprocessing.connection
 import os
 import os
 import threading
 import threading
 import uuid
 import uuid
+from contextlib import nullcontext
 from enum import Enum, auto
 from enum import Enum, auto
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type
+from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
+from weakref import ref
 
 
 import torch  # used for py3.7-compatible shared memory
 import torch  # used for py3.7-compatible shared memory
 
 
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
-
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 # flavour types
 # flavour types

+ 0 - 1
hivemind/utils/networking.py

@@ -5,7 +5,6 @@ from typing import Optional, Sequence
 
 
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
-
 Hostname, Port = str, int  # flavour types
 Hostname, Port = str, int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = "127.0.0.1"
 LOCALHOST = "127.0.0.1"

+ 1 - 1
hivemind/utils/serializer.py

@@ -1,6 +1,6 @@
 """ A unified interface for several common serialization methods """
 """ A unified interface for several common serialization methods """
-from typing import Dict, Any
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
+from typing import Any, Dict
 
 
 import msgpack
 import msgpack
 
 

+ 1 - 1
hivemind/utils/tensor_descr.py

@@ -1,5 +1,5 @@
 import warnings
 import warnings
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 
 
 import torch
 import torch
 
 

+ 2 - 1
hivemind/utils/timed_storage.py

@@ -1,10 +1,11 @@
 """ A dictionary-like storage that stores items until a specified expiration time or up to a limited size """
 """ A dictionary-like storage that stores items until a specified expiration time or up to a limited size """
 from __future__ import annotations
 from __future__ import annotations
+
 import heapq
 import heapq
 import time
 import time
 from contextlib import contextmanager
 from contextlib import contextmanager
-from typing import TypeVar, Generic, Optional, Dict, List, Iterator, Tuple
 from dataclasses import dataclass
 from dataclasses import dataclass
+from typing import Dict, Generic, Iterator, List, Optional, Tuple, TypeVar
 
 
 KeyType = TypeVar("KeyType")
 KeyType = TypeVar("KeyType")
 ValueType = TypeVar("ValueType")
 ValueType = TypeVar("ValueType")

+ 7 - 0
pyproject.toml

@@ -1,3 +1,10 @@
 [tool.black]
 [tool.black]
 line-length = 119
 line-length = 119
 required-version = "21.6b0"
 required-version = "21.6b0"
+
+[tool.isort]
+profile = "black"
+line_length = 119
+combine_as_imports = true
+combine_star = true
+known_local_folder = ["arguments", "test_utils", "tests", "utils"]

+ 1 - 1
requirements-dev.txt

@@ -2,8 +2,8 @@ pytest
 pytest-forked
 pytest-forked
 pytest-asyncio
 pytest-asyncio
 pytest-cov
 pytest-cov
-codecov
 tqdm
 tqdm
 scikit-learn
 scikit-learn
 black==21.6b0
 black==21.6b0
+isort
 psutil
 psutil

+ 2 - 3
tests/conftest.py

@@ -1,13 +1,12 @@
 import gc
 import gc
-from contextlib import suppress
 import multiprocessing as mp
 import multiprocessing as mp
+from contextlib import suppress
 
 
 import psutil
 import psutil
 import pytest
 import pytest
 
 
-from hivemind.utils.mpfuture import MPFuture, SharedBytes
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
-
+from hivemind.utils.mpfuture import MPFuture, SharedBytes
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 1 - 2
tests/test_auth.py

@@ -5,11 +5,10 @@ import pytest
 
 
 from hivemind.proto import dht_pb2
 from hivemind.proto import dht_pb2
 from hivemind.proto.auth_pb2 import AccessToken
 from hivemind.proto.auth_pb2 import AccessToken
-from hivemind.utils.auth import AuthRPCWrapper, AuthRole, TokenAuthorizerBase
+from hivemind.utils.auth import AuthRole, AuthRPCWrapper, TokenAuthorizerBase
 from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
-
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 

+ 1 - 0
tests/test_averaging.py

@@ -12,6 +12,7 @@ from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.p2p import PeerID
 from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
+
 from test_utils.dht_swarms import launch_dht_instances
 from test_utils.dht_swarms import launch_dht_instances
 
 
 
 

+ 1 - 0
tests/test_dht.py

@@ -6,6 +6,7 @@ import pytest
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
 import hivemind
 import hivemind
+
 from test_utils.dht_swarms import launch_dht_instances
 from test_utils.dht_swarms import launch_dht_instances
 
 
 
 

+ 2 - 2
tests/test_dht_crypto.py

@@ -1,15 +1,15 @@
 import dataclasses
 import dataclasses
-import pickle
 import multiprocessing as mp
 import multiprocessing as mp
+import pickle
 
 
 import pytest
 import pytest
 
 
 import hivemind
 import hivemind
-from hivemind.utils.timed_storage import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.node import DHTNode
 from hivemind.dht.node import DHTNode
 from hivemind.dht.validation import DHTRecord
 from hivemind.dht.validation import DHTRecord
 from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.timed_storage import get_dht_time
 
 
 
 
 def test_rsa_signature_validator():
 def test_rsa_signature_validator():

+ 2 - 2
tests/test_dht_experts.py

@@ -6,11 +6,11 @@ import numpy as np
 import pytest
 import pytest
 
 
 import hivemind
 import hivemind
-from hivemind.dht import DHTNode
 from hivemind import LOCALHOST
 from hivemind import LOCALHOST
+from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.server import declare_experts, get_experts
 from hivemind.moe.server import declare_experts, get_experts
-from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
+from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked

+ 1 - 1
tests/test_dht_node.py

@@ -17,8 +17,8 @@ from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.p2p import P2P, PeerID
 from hivemind.p2p import P2P, PeerID
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
-from test_utils.dht_swarms import launch_swarm_in_separate_processes, launch_star_shaped_swarm
 
 
+from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 2 - 2
tests/test_dht_storage.py

@@ -1,8 +1,8 @@
 import time
 import time
 
 
-from hivemind.utils.timed_storage import get_dht_time
-from hivemind.dht.storage import DHTLocalStorage, DHTID, DictionaryDHTValue
+from hivemind.dht.storage import DHTID, DHTLocalStorage, DictionaryDHTValue
 from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.timed_storage import get_dht_time
 
 
 
 
 def test_store():
 def test_store():

+ 1 - 1
tests/test_dht_validation.py

@@ -9,7 +9,7 @@ from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID
 from hivemind.dht.routing import DHTID
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.dht.validation import DHTRecord, CompositeValidator
+from hivemind.dht.validation import CompositeValidator, DHTRecord
 
 
 
 
 class SchemaA(BaseModel):
 class SchemaA(BaseModel):

+ 1 - 1
tests/test_expert_backend.py

@@ -6,7 +6,7 @@ import torch
 from torch.nn import Linear
 from torch.nn import Linear
 
 
 from hivemind import BatchTensorDescriptor, ExpertBackend
 from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.moe.server.checkpoints import store_experts, load_experts
+from hivemind.moe.server.checkpoints import load_experts, store_experts
 from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
 from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
 
 EXPERT_WEIGHT_UPDATES = 3
 EXPERT_WEIGHT_UPDATES = 3

+ 1 - 2
tests/test_moe.py

@@ -4,9 +4,8 @@ import pytest
 import torch
 import torch
 
 
 import hivemind
 import hivemind
-from hivemind.moe.server import background_server, declare_experts
 from hivemind.moe.client.expert import DUMMY
 from hivemind.moe.client.expert import DUMMY
-from hivemind.moe.server import layers
+from hivemind.moe.server import background_server, declare_experts, layers
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked

+ 2 - 1
tests/test_p2p_daemon_bindings.py

@@ -17,7 +17,8 @@ from hivemind.p2p.p2p_daemon_bindings.utils import (
     write_unsigned_varint,
     write_unsigned_varint,
 )
 )
 from hivemind.proto import p2pd_pb2 as p2pd_pb
 from hivemind.proto import p2pd_pb2 as p2pd_pb
-from test_utils.p2p_daemon import make_p2pd_pair_ip4, connect_safe
+
+from test_utils.p2p_daemon import connect_safe, make_p2pd_pair_ip4
 
 
 
 
 def test_raise_if_failed_raises():
 def test_raise_if_failed_raises():

+ 2 - 2
tests/test_routing.py

@@ -1,10 +1,10 @@
-import random
 import heapq
 import heapq
 import operator
 import operator
+import random
 from itertools import chain, zip_longest
 from itertools import chain, zip_longest
 
 
 from hivemind import LOCALHOST
 from hivemind import LOCALHOST
-from hivemind.dht.routing import RoutingTable, DHTID
+from hivemind.dht.routing import DHTID, RoutingTable
 
 
 
 
 def test_ids_basic():
 def test_ids_basic():

+ 1 - 1
tests/test_training.py

@@ -10,7 +10,7 @@ from sklearn.datasets import load_digits
 from hivemind import DHT
 from hivemind import DHT
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import background_server
 from hivemind.moe.server import background_server
-from hivemind.optim import DecentralizedSGD, DecentralizedAdam
+from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked

+ 3 - 3
tests/test_util_modules.py

@@ -12,9 +12,9 @@ import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import MSGPackSerializer, ValueWithExpiration, HeapEntry, DHTExpiration
-from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 from hivemind.utils.mpfuture import InvalidStateError
 
 
 
 

+ 1 - 2
tests/test_utils/p2p_daemon.py

@@ -6,14 +6,13 @@ import time
 import uuid
 import uuid
 from contextlib import asynccontextmanager
 from contextlib import asynccontextmanager
 from typing import NamedTuple
 from typing import NamedTuple
-from pkg_resources import resource_filename
 
 
 from multiaddr import Multiaddr, protocols
 from multiaddr import Multiaddr, protocols
+from pkg_resources import resource_filename
 
 
 from hivemind import find_open_port
 from hivemind import find_open_port
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 
 
-
 TIMEOUT_DURATION = 30  # seconds
 TIMEOUT_DURATION = 30  # seconds
 P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")
 P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")