Эх сурвалжийг харах

Reorder imports with isort (#326)

* Apply isort everywhere

* Config isort in CI

* Update pyproject.toml

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Michael Diskin 4 жил өмнө
parent
commit
bedfa6eefb
68 өөрчлөгдсөн 179 нэмэгдсэн, 165 устгасан
  1. 21 0
      .github/workflows/check-style.yml
  2. 0 13
      .github/workflows/check_style.yml
  3. 3 2
      CONTRIBUTING.md
  4. 1 2
      benchmarks/benchmark_tensor_compression.py
  5. 1 2
      docs/conf.py
  6. 1 1
      examples/albert/arguments.py
  7. 3 3
      examples/albert/run_trainer.py
  8. 2 2
      examples/albert/run_training_monitor.py
  9. 0 1
      examples/albert/utils.py
  10. 4 4
      hivemind/__init__.py
  11. 4 4
      hivemind/averaging/allreduce.py
  12. 7 7
      hivemind/averaging/averager.py
  13. 3 3
      hivemind/averaging/key_manager.py
  14. 2 1
      hivemind/averaging/load_balancing.py
  15. 3 3
      hivemind/averaging/matchmaking.py
  16. 3 4
      hivemind/averaging/partition.py
  17. 3 3
      hivemind/averaging/training.py
  18. 0 1
      hivemind/dht/crypto.py
  19. 3 3
      hivemind/dht/node.py
  20. 6 6
      hivemind/dht/protocol.py
  21. 2 1
      hivemind/dht/routing.py
  22. 1 1
      hivemind/dht/storage.py
  23. 1 1
      hivemind/dht/traverse.py
  24. 2 2
      hivemind/hivemind_cli/run_server.py
  25. 1 1
      hivemind/moe/__init__.py
  26. 8 8
      hivemind/moe/client/beam_search.py
  27. 3 3
      hivemind/moe/client/expert.py
  28. 6 6
      hivemind/moe/client/moe.py
  29. 3 3
      hivemind/moe/client/switch_moe.py
  30. 11 6
      hivemind/moe/server/__init__.py
  31. 3 3
      hivemind/moe/server/connection_handler.py
  32. 5 5
      hivemind/moe/server/dht_handler.py
  33. 3 3
      hivemind/moe/server/expert_backend.py
  34. 1 1
      hivemind/moe/server/expert_uid.py
  35. 1 1
      hivemind/moe/server/layers/custom_experts.py
  36. 1 1
      hivemind/moe/server/runtime.py
  37. 2 2
      hivemind/moe/server/task_pool.py
  38. 1 1
      hivemind/optim/__init__.py
  39. 1 1
      hivemind/optim/adaptive.py
  40. 2 2
      hivemind/optim/collaborative.py
  41. 3 3
      hivemind/optim/simple.py
  42. 3 3
      hivemind/utils/__init__.py
  43. 2 3
      hivemind/utils/asyncio.py
  44. 2 3
      hivemind/utils/compression.py
  45. 2 2
      hivemind/utils/grpc.py
  46. 3 4
      hivemind/utils/mpfuture.py
  47. 0 1
      hivemind/utils/networking.py
  48. 1 1
      hivemind/utils/serializer.py
  49. 1 1
      hivemind/utils/tensor_descr.py
  50. 2 1
      hivemind/utils/timed_storage.py
  51. 7 0
      pyproject.toml
  52. 1 1
      requirements-dev.txt
  53. 2 3
      tests/conftest.py
  54. 1 2
      tests/test_auth.py
  55. 1 0
      tests/test_averaging.py
  56. 1 0
      tests/test_dht.py
  57. 2 2
      tests/test_dht_crypto.py
  58. 2 2
      tests/test_dht_experts.py
  59. 1 1
      tests/test_dht_node.py
  60. 2 2
      tests/test_dht_storage.py
  61. 1 1
      tests/test_dht_validation.py
  62. 1 1
      tests/test_expert_backend.py
  63. 1 2
      tests/test_moe.py
  64. 2 1
      tests/test_p2p_daemon_bindings.py
  65. 2 2
      tests/test_routing.py
  66. 1 1
      tests/test_training.py
  67. 3 3
      tests/test_util_modules.py
  68. 1 2
      tests/test_utils/p2p_daemon.py

+ 21 - 0
.github/workflows/check-style.yml

@@ -0,0 +1,21 @@
+name: Check style
+
+on: [ push ]
+
+jobs:
+  black:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: psf/black@stable
+        with:
+          options: "--check --diff"
+          version: "21.6b0"
+  isort:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - uses: isort/isort-action@master

+ 0 - 13
.github/workflows/check_style.yml

@@ -1,13 +0,0 @@
-name: Check style
-
-on: [ push ]
-
-jobs:
-  black:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v2
-      - uses: psf/black@stable
-        with:
-          options: "--check"
-          version: "21.6b0"

+ 3 - 2
CONTRIBUTING.md

@@ -34,10 +34,11 @@ with the following rules:
 
 ## Code style
 
-* We use [black](https://github.com/psf/black) for code formatting. Before submitting a PR, make sure to install and
-  run `black .` in the root of the repository.
 * The code must follow [PEP8](https://www.python.org/dev/peps/pep-0008/) unless absolutely necessary. Also, each line
   cannot be longer than 119 characters.
+* We use [black](https://github.com/psf/black) for code formatting and [isort](https://github.com/PyCQA/isort) for 
+  import sorting. Before submitting a PR, make sure to install and run `black .` and `isort .` in the root of the
+  repository.
 * We highly encourage the use of [typing](https://docs.python.org/3/library/typing.html) where applicable.
 * Use `get_logger` from `hivemind.utils.logging` to log any information instead of `print`ing directly to standard
   output/error streams.

+ 1 - 2
benchmarks/benchmark_tensor_compression.py

@@ -4,10 +4,9 @@ import time
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
-
 logger = get_logger(__name__)
 
 

+ 1 - 2
docs/conf.py

@@ -17,9 +17,8 @@
 # sys.path.insert(0, os.path.abspath('.'))
 import sys
 
-from recommonmark.transform import AutoStructify
 from recommonmark.parser import CommonMarkParser
-
+from recommonmark.transform import AutoStructify
 
 # -- Project information -----------------------------------------------------
 src_path = "../hivemind"

+ 1 - 1
examples/albert/arguments.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Optional, List
+from typing import List, Optional
 
 from transformers import TrainingArguments
 

+ 3 - 3
examples/albert/run_trainer.py

@@ -11,8 +11,8 @@ import transformers
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
 from torch_optimizer import Lamb
-from transformers import set_seed, HfArgumentParser, TrainingArguments, DataCollatorForLanguageModeling
-from transformers.models.albert import AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining
+from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
+from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
 from transformers.trainer_utils import is_main_process
@@ -21,7 +21,7 @@ import hivemind
 from hivemind.utils.compression import CompressionType
 
 import utils
-from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments
+from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
 
 logger = logging.getLogger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

+ 2 - 2
examples/albert/run_training_monitor.py

@@ -10,13 +10,13 @@ import requests
 import torch
 import wandb
 from torch_optimizer import Lamb
-from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
 
 import hivemind
 from hivemind.utils.compression import CompressionType
 
 import utils
-from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
+from arguments import AveragerArguments, BaseTrainingArguments, CollaborativeOptimizerArguments
 
 logger = logging.getLogger(__name__)
 

+ 0 - 1
examples/albert/utils.py

@@ -9,7 +9,6 @@ from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import RecordValidatorBase
 from hivemind.utils.logging import get_logger
 
-
 logger = get_logger(__name__)
 
 

+ 4 - 4
hivemind/__init__.py

@@ -2,19 +2,19 @@ from hivemind.averaging import DecentralizedAverager, TrainingAverager
 from hivemind.dht import DHT
 from hivemind.moe import (
     ExpertBackend,
-    Server,
-    register_expert_class,
     RemoteExpert,
     RemoteMixtureOfExperts,
     RemoteSwitchMixtureOfExperts,
+    Server,
+    register_expert_class,
 )
 from hivemind.optim import (
     CollaborativeAdaptiveOptimizer,
-    DecentralizedOptimizerBase,
     CollaborativeOptimizer,
+    DecentralizedAdam,
     DecentralizedOptimizer,
+    DecentralizedOptimizerBase,
     DecentralizedSGD,
-    DecentralizedAdam,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *

+ 4 - 4
hivemind/averaging/allreduce.py

@@ -4,12 +4,12 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 
 import torch
 
-from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
+from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
-from hivemind.utils import get_logger
-from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor, asingle
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.proto import averaging_pb2
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, asingle
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 # flavour types
 GroupID = bytes

+ 7 - 7
hivemind/averaging/averager.py

@@ -11,12 +11,12 @@ import threading
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
-from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
 
-from hivemind.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
+from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -24,12 +24,12 @@ from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2, runtime_pb2
-from hivemind.utils import MPFuture, get_logger, TensorDescriptor
-from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.utils.grpc import split_for_streaming, combine_from_streaming
+from hivemind.utils import MPFuture, TensorDescriptor, get_logger
+from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
+from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
 
 # flavour types
 GatheredData = Any

+ 3 - 3
hivemind/averaging/key_manager.py

@@ -1,14 +1,14 @@
 import asyncio
-import re
 import random
-from typing import Optional, List, Tuple
+import re
+from typing import List, Optional, Tuple
 
 import numpy as np
 
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.dht import DHT
 from hivemind.p2p import PeerID
-from hivemind.utils import get_logger, DHTExpiration, get_dht_time, ValueWithExpiration
+from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101

+ 2 - 1
hivemind/averaging/load_balancing.py

@@ -1,4 +1,5 @@
-from typing import Sequence, Optional, Tuple
+from typing import Optional, Sequence, Tuple
+
 import numpy as np
 import scipy.optimize
 

+ 3 - 3
hivemind/averaging/matchmaking.py

@@ -10,12 +10,12 @@ from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
+from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
-from hivemind.utils.asyncio import anext
 from hivemind.proto import averaging_pb2
+from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
+from hivemind.utils.asyncio import anext
 
 logger = get_logger(__name__)
 

+ 3 - 4
hivemind/averaging/partition.py

@@ -2,16 +2,15 @@
 Auxiliary data structures for AllReduceRunner
 """
 import asyncio
-from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
 from collections import deque
+from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, TypeVar, Union
 
-import torch
 import numpy as np
+import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType, Tensor
-from hivemind.utils.compression import serialize_torch_tensor, get_nbytes_per_value
 from hivemind.utils.asyncio import amap_in_executor
-
+from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor
 
 T = TypeVar("T")
 DEFAULT_PART_SIZE_BYTES = 2 ** 19

+ 3 - 3
hivemind/averaging/training.py

@@ -2,13 +2,13 @@
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
-from threading import Lock, Event
-from typing import Sequence, Dict, Iterator, Optional
+from threading import Event, Lock
+from typing import Dict, Iterator, Optional, Sequence
 
 import torch
 
 from hivemind.averaging import DecentralizedAverager
-from hivemind.utils import nested_flatten, nested_pack, get_logger
+from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 

+ 0 - 1
hivemind/dht/crypto.py

@@ -6,7 +6,6 @@ from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 from hivemind.utils import MSGPackSerializer, get_logger
 from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
 
-
 logger = get_logger(__name__)
 
 

+ 3 - 3
hivemind/dht/node.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 import dataclasses
 import random
-from collections import defaultdict, Counter
+from collections import Counter, defaultdict
 from dataclasses import dataclass, field
 from functools import partial
 from typing import (
@@ -27,11 +27,11 @@ from sortedcontainers import SortedSet
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.routing import DHTID, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, BinaryDHTValue, DHTKey, DHTValue, Subkey, get_dht_time
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
 from hivemind.p2p import P2P, PeerID
-from hivemind.utils import MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils import MSGPackSerializer, SerializerBase, get_logger
 from hivemind.utils.auth import AuthorizerBase
 from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithExpiration
 

+ 6 - 6
hivemind/dht/protocol.py

@@ -2,20 +2,20 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Optional, List, Tuple, Dict, Sequence, Union, Collection
+from typing import Collection, Dict, List, Optional, Sequence, Tuple, Union
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
-from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, BinaryDHTValue, RoutingTable, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
 from hivemind.proto import dht_pb2
-from hivemind.utils import get_logger, MSGPackSerializer
-from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
+from hivemind.utils import MSGPackSerializer, get_logger
+from hivemind.utils.auth import AuthorizerBase, AuthRole, AuthRPCWrapper
 from hivemind.utils.timed_storage import (
-    DHTExpiration,
-    get_dht_time,
     MAX_DHT_TIME_DISCREPANCY_SECONDS,
+    DHTExpiration,
     ValueWithExpiration,
+    get_dht_time,
 )
 
 logger = get_logger(__name__)

+ 2 - 1
hivemind/dht/routing.py

@@ -7,7 +7,8 @@ import os
 import random
 from collections.abc import Iterable
 from itertools import chain
-from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+
 from hivemind.p2p import PeerID
 from hivemind.utils import MSGPackSerializer, get_dht_time
 

+ 1 - 1
hivemind/dht/storage.py

@@ -4,7 +4,7 @@ from typing import Optional, Union
 
 from hivemind.dht.routing import DHTID, BinaryDHTValue, Subkey
 from hivemind.utils.serializer import MSGPackSerializer
-from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage, DHTExpiration
+from hivemind.utils.timed_storage import DHTExpiration, KeyType, TimedStorage, ValueType
 
 
 @MSGPackSerializer.ext_serializable(0x50)

+ 1 - 1
hivemind/dht/traverse.py

@@ -2,7 +2,7 @@
 import asyncio
 import heapq
 from collections import Counter
-from typing import Dict, Awaitable, Callable, Any, Tuple, List, Set, Collection, Optional
+from typing import Any, Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple
 
 from hivemind.dht.routing import DHTID
 

+ 2 - 2
hivemind/hivemind_cli/run_server.py

@@ -4,11 +4,11 @@ from pathlib import Path
 import configargparse
 import torch
 
-from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.moe.server import Server
+from hivemind.moe.server.layers import schedule_name_to_scheduler
+from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
-from hivemind.moe.server.layers import schedule_name_to_scheduler
 
 logger = get_logger(__name__)
 

+ 1 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,2 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, register_expert_class, get_experts, declare_experts
+from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class

+ 8 - 8
hivemind/moe/client/beam_search.py

@@ -2,22 +2,22 @@ import asyncio
 import heapq
 from collections import deque
 from functools import partial
-from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
+from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
-from hivemind.dht import DHT, DHTNode, DHTExpiration
+from hivemind.dht import DHT, DHTExpiration, DHTNode
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import (
-    ExpertUID,
-    ExpertPrefix,
     FLAT_EXPERT,
-    UidEndpoint,
-    Score,
-    Coordinate,
     PREFIX_PATTERN,
     UID_DELIMITER,
+    Coordinate,
+    ExpertPrefix,
+    ExpertUID,
+    Score,
+    UidEndpoint,
     is_valid_prefix,
 )
-from hivemind.utils import get_logger, get_dht_time, MPFuture
+from hivemind.utils import MPFuture, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 

+ 3 - 3
hivemind/moe/client/expert.py

@@ -1,13 +1,13 @@
 import pickle
-from typing import Tuple, Optional, Any, Dict
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import Endpoint, nested_compare, nested_flatten, nested_pack
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert

+ 6 - 6
hivemind/moe/client/moe.py

@@ -1,8 +1,8 @@
 from __future__ import annotations
 
 import time
-from queue import Queue, Empty
-from typing import Tuple, List, Optional, Dict, Any
+from queue import Empty, Queue
+from typing import Any, Dict, List, Optional, Tuple
 
 import grpc
 import torch
@@ -11,11 +11,11 @@ from torch.autograd.function import once_differentiable
 
 import hivemind
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import RemoteExpert, DUMMY, _get_expert_stub
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten, nested_map
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import nested_flatten, nested_map, nested_pack
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 3 - 3
hivemind/moe/client/switch_moe.py

@@ -1,14 +1,14 @@
 from __future__ import annotations
 
-from typing import Tuple, List
+from typing import List, Tuple
 
 import grpc
 import torch
 
-from hivemind.moe.client.expert import RemoteExpert, DUMMY
+from hivemind.moe.client.expert import DUMMY, RemoteExpert
 from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
 from hivemind.moe.server.expert_uid import UID_DELIMITER
-from hivemind.utils import nested_pack, nested_flatten
+from hivemind.utils import nested_flatten, nested_pack
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 11 - 6
hivemind/moe/server/__init__.py

@@ -5,24 +5,29 @@ import multiprocessing.synchronize
 import threading
 from contextlib import contextmanager
 from functools import partial
-from typing import Dict, List, Optional, Tuple
 from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from multiaddr import Multiaddr
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
-from hivemind.moe.server.checkpoints import CheckpointSaver, load_experts, is_directory
+from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.moe.server.layers import name_to_block, name_to_input, register_expert_class
-from hivemind.moe.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
+from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.moe.server.layers import (
+    add_custom_models_from_file,
+    name_to_block,
+    name_to_input,
+    register_expert_class,
+    schedule_name_to_scheduler,
+)
 from hivemind.moe.server.runtime import Runtime
-from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger, BatchTensorDescriptor
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils import BatchTensorDescriptor, Endpoint, find_open_port, get_logger, get_port, replace_port
 
 logger = get_logger(__name__)
 

+ 3 - 3
hivemind/moe/server/connection_handler.py

@@ -6,11 +6,11 @@ from typing import Dict
 import grpc
 import torch
 
-from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
 from hivemind.moe.server.expert_backend import ExpertBackend
-from hivemind.utils import get_logger, Endpoint, nested_flatten
+from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
+from hivemind.utils import Endpoint, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
 
 logger = get_logger(__name__)

+ 5 - 5
hivemind/moe/server/dht_handler.py

@@ -1,16 +1,16 @@
 import threading
 from functools import partial
-from typing import Sequence, Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Sequence, Tuple
 
-from hivemind.dht import DHT, DHTNode, DHTExpiration, DHTValue
+from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert
 from hivemind.moe.server.expert_uid import (
-    ExpertUID,
-    ExpertPrefix,
     FLAT_EXPERT,
-    Coordinate,
     UID_DELIMITER,
     UID_PATTERN,
+    Coordinate,
+    ExpertPrefix,
+    ExpertUID,
     is_valid_uid,
     split_uid,
 )

+ 3 - 3
hivemind/moe/server/expert_backend.py

@@ -1,12 +1,12 @@
-from typing import Dict, Sequence, Any, Tuple, Union, Callable
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
 
 import torch
 from torch import nn
 
 from hivemind.moe.server.task_pool import TaskPool
-from hivemind.utils.tensor_descr import BatchTensorDescriptor, DUMMY_BATCH_SIZE
 from hivemind.utils.logging import get_logger
-from hivemind.utils.nested import nested_flatten, nested_pack, nested_compare, nested_map
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
+from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
 logger = get_logger(__name__)
 

+ 1 - 1
hivemind/moe/server/expert_uid.py

@@ -1,6 +1,6 @@
 import random
 import re
-from typing import NamedTuple, Union, Tuple, Optional, List
+from typing import List, NamedTuple, Optional, Tuple, Union
 
 import hivemind
 from hivemind.dht import DHT

+ 1 - 1
hivemind/moe/server/layers/custom_experts.py

@@ -1,5 +1,5 @@
-import os
 import importlib
+import os
 from typing import Callable, Type
 
 import torch

+ 1 - 1
hivemind/moe/server/runtime.py

@@ -4,7 +4,7 @@ import threading
 from collections import defaultdict
 from itertools import chain
 from queue import SimpleQueue
-from selectors import DefaultSelector, EVENT_READ
+from selectors import EVENT_READ, DefaultSelector
 from statistics import mean
 from time import time
 from typing import Dict, NamedTuple, Optional

+ 2 - 2
hivemind/moe/server/task_pool.py

@@ -10,12 +10,12 @@ from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from concurrent.futures import Future
 from queue import Empty
-from typing import List, Tuple, Dict, Any, Generator
+from typing import Any, Dict, Generator, List, Tuple
 
 import torch
 
 from hivemind.utils import get_logger
-from hivemind.utils.mpfuture import MPFuture, InvalidStateError
+from hivemind.utils.mpfuture import InvalidStateError, MPFuture
 
 logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.simple import DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
+from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence
 
 import torch.optim
 
-from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind import TrainingAverager
+from hivemind.optim.collaborative import CollaborativeOptimizer
 
 
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 2 - 2
hivemind/optim/collaborative.py

@@ -2,8 +2,8 @@ from __future__ import annotations
 
 import logging
 from dataclasses import dataclass
-from threading import Thread, Lock, Event
-from typing import Dict, Optional, Iterator
+from threading import Event, Lock, Thread
+from typing import Dict, Iterator, Optional
 
 import numpy as np
 import torch

+ 3 - 3
hivemind/optim/simple.py

@@ -1,13 +1,13 @@
 import time
-from threading import Thread, Lock, Event
+from threading import Event, Lock, Thread
 from typing import Optional, Sequence, Tuple
 
 import torch
 
-from hivemind.dht import DHT
 from hivemind.averaging import TrainingAverager
+from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.utils import get_logger, get_dht_time
+from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)
 

+ 3 - 3
hivemind/utils/__init__.py

@@ -1,11 +1,11 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
-from hivemind.utils.serializer import SerializerBase, MSGPackSerializer
-from hivemind.utils.tensor_descr import TensorDescriptor, BatchTensorDescriptor
+from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 2 - 3
hivemind/utils/asyncio.py

@@ -1,12 +1,11 @@
-from concurrent.futures import ThreadPoolExecutor
-from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable, Tuple, Optional, Callable
 import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
 from hivemind.utils.logging import get_logger
 
-
 T = TypeVar("T")
 logger = get_logger(__name__)
 

+ 2 - 3
hivemind/utils/compression.py

@@ -1,15 +1,14 @@
 import os
+import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple, Sequence, Optional
+from typing import Optional, Sequence, Tuple
 
 import numpy as np
 import torch
-import warnings
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
 
-
 FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
 NUM_BYTES_FLOAT16 = 2

+ 2 - 2
hivemind/utils/grpc.py

@@ -6,14 +6,14 @@ from __future__ import annotations
 
 import os
 import threading
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable
+from typing import Any, Dict, Iterable, Iterator, NamedTuple, Optional, Tuple, Type, TypeVar, Union
 
 import grpc
 
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 from hivemind.utils.networking import Endpoint
-from hivemind.utils.timed_storage import TimedStorage, get_dht_time, ValueWithExpiration
+from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
 
 logger = get_logger(__name__)
 

+ 3 - 4
hivemind/utils/mpfuture.py

@@ -2,21 +2,20 @@ from __future__ import annotations
 
 import asyncio
 import concurrent.futures._base as base
-from weakref import ref
-from contextlib import nullcontext
 import multiprocessing as mp
 import multiprocessing.connection
 import os
 import threading
 import uuid
+from contextlib import nullcontext
 from enum import Enum, auto
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type
+from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
+from weakref import ref
 
 import torch  # used for py3.7-compatible shared memory
 
 from hivemind.utils.logging import get_logger
 
-
 logger = get_logger(__name__)
 
 # flavour types

+ 0 - 1
hivemind/utils/networking.py

@@ -5,7 +5,6 @@ from typing import Optional, Sequence
 
 from multiaddr import Multiaddr
 
-
 Hostname, Port = str, int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = "127.0.0.1"

+ 1 - 1
hivemind/utils/serializer.py

@@ -1,6 +1,6 @@
 """ A unified interface for several common serialization methods """
-from typing import Dict, Any
 from abc import ABC, abstractmethod
+from typing import Any, Dict
 
 import msgpack
 

+ 1 - 1
hivemind/utils/tensor_descr.py

@@ -1,5 +1,5 @@
 import warnings
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 
 import torch
 

+ 2 - 1
hivemind/utils/timed_storage.py

@@ -1,10 +1,11 @@
 """ A dictionary-like storage that stores items until a specified expiration time or up to a limited size """
 from __future__ import annotations
+
 import heapq
 import time
 from contextlib import contextmanager
-from typing import TypeVar, Generic, Optional, Dict, List, Iterator, Tuple
 from dataclasses import dataclass
+from typing import Dict, Generic, Iterator, List, Optional, Tuple, TypeVar
 
 KeyType = TypeVar("KeyType")
 ValueType = TypeVar("ValueType")

+ 7 - 0
pyproject.toml

@@ -1,3 +1,10 @@
 [tool.black]
 line-length = 119
 required-version = "21.6b0"
+
+[tool.isort]
+profile = "black"
+line_length = 119
+combine_as_imports = true
+combine_star = true
+known_local_folder = ["arguments", "test_utils", "tests", "utils"]

+ 1 - 1
requirements-dev.txt

@@ -2,8 +2,8 @@ pytest
 pytest-forked
 pytest-asyncio
 pytest-cov
-codecov
 tqdm
 scikit-learn
 black==21.6b0
+isort
 psutil

+ 2 - 3
tests/conftest.py

@@ -1,13 +1,12 @@
 import gc
-from contextlib import suppress
 import multiprocessing as mp
+from contextlib import suppress
 
 import psutil
 import pytest
 
-from hivemind.utils.mpfuture import MPFuture, SharedBytes
 from hivemind.utils.logging import get_logger
-
+from hivemind.utils.mpfuture import MPFuture, SharedBytes
 
 logger = get_logger(__name__)
 

+ 1 - 2
tests/test_auth.py

@@ -5,11 +5,10 @@ import pytest
 
 from hivemind.proto import dht_pb2
 from hivemind.proto.auth_pb2 import AccessToken
-from hivemind.utils.auth import AuthRPCWrapper, AuthRole, TokenAuthorizerBase
+from hivemind.utils.auth import AuthRole, AuthRPCWrapper, TokenAuthorizerBase
 from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger
 
-
 logger = get_logger(__name__)
 
 

+ 1 - 0
tests/test_averaging.py

@@ -12,6 +12,7 @@ from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
+
 from test_utils.dht_swarms import launch_dht_instances
 
 

+ 1 - 0
tests/test_dht.py

@@ -6,6 +6,7 @@ import pytest
 from multiaddr import Multiaddr
 
 import hivemind
+
 from test_utils.dht_swarms import launch_dht_instances
 
 

+ 2 - 2
tests/test_dht_crypto.py

@@ -1,15 +1,15 @@
 import dataclasses
-import pickle
 import multiprocessing as mp
+import pickle
 
 import pytest
 
 import hivemind
-from hivemind.utils.timed_storage import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.node import DHTNode
 from hivemind.dht.validation import DHTRecord
 from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.timed_storage import get_dht_time
 
 
 def test_rsa_signature_validator():

+ 2 - 2
tests/test_dht_experts.py

@@ -6,11 +6,11 @@ import numpy as np
 import pytest
 
 import hivemind
-from hivemind.dht import DHTNode
 from hivemind import LOCALHOST
+from hivemind.dht import DHTNode
 from hivemind.moe.client.beam_search import MoEBeamSearcher
 from hivemind.moe.server import declare_experts, get_experts
-from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
+from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_prefix, is_valid_uid, split_uid
 
 
 @pytest.mark.forked

+ 1 - 1
tests/test_dht_node.py

@@ -17,8 +17,8 @@ from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.p2p import P2P, PeerID
 from hivemind.utils.logging import get_logger
-from test_utils.dht_swarms import launch_swarm_in_separate_processes, launch_star_shaped_swarm
 
+from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes
 
 logger = get_logger(__name__)
 

+ 2 - 2
tests/test_dht_storage.py

@@ -1,8 +1,8 @@
 import time
 
-from hivemind.utils.timed_storage import get_dht_time
-from hivemind.dht.storage import DHTLocalStorage, DHTID, DictionaryDHTValue
+from hivemind.dht.storage import DHTID, DHTLocalStorage, DictionaryDHTValue
 from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.timed_storage import get_dht_time
 
 
 def test_store():

+ 1 - 1
tests/test_dht_validation.py

@@ -9,7 +9,7 @@ from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.dht.validation import DHTRecord, CompositeValidator
+from hivemind.dht.validation import CompositeValidator, DHTRecord
 
 
 class SchemaA(BaseModel):

+ 1 - 1
tests/test_expert_backend.py

@@ -6,7 +6,7 @@ import torch
 from torch.nn import Linear
 
 from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.moe.server.checkpoints import store_experts, load_experts
+from hivemind.moe.server.checkpoints import load_experts, store_experts
 from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
 EXPERT_WEIGHT_UPDATES = 3

+ 1 - 2
tests/test_moe.py

@@ -4,9 +4,8 @@ import pytest
 import torch
 
 import hivemind
-from hivemind.moe.server import background_server, declare_experts
 from hivemind.moe.client.expert import DUMMY
-from hivemind.moe.server import layers
+from hivemind.moe.server import background_server, declare_experts, layers
 
 
 @pytest.mark.forked

+ 2 - 1
tests/test_p2p_daemon_bindings.py

@@ -17,7 +17,8 @@ from hivemind.p2p.p2p_daemon_bindings.utils import (
     write_unsigned_varint,
 )
 from hivemind.proto import p2pd_pb2 as p2pd_pb
-from test_utils.p2p_daemon import make_p2pd_pair_ip4, connect_safe
+
+from test_utils.p2p_daemon import connect_safe, make_p2pd_pair_ip4
 
 
 def test_raise_if_failed_raises():

+ 2 - 2
tests/test_routing.py

@@ -1,10 +1,10 @@
-import random
 import heapq
 import operator
+import random
 from itertools import chain, zip_longest
 
 from hivemind import LOCALHOST
-from hivemind.dht.routing import RoutingTable, DHTID
+from hivemind.dht.routing import DHTID, RoutingTable
 
 
 def test_ids_basic():

+ 1 - 1
tests/test_training.py

@@ -10,7 +10,7 @@ from sklearn.datasets import load_digits
 from hivemind import DHT
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
 from hivemind.moe.server import background_server
-from hivemind.optim import DecentralizedSGD, DecentralizedAdam
+from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
 
 @pytest.mark.forked

+ 3 - 3
tests/test_util_modules.py

@@ -12,9 +12,9 @@ import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-from hivemind.utils import MSGPackSerializer, ValueWithExpiration, HeapEntry, DHTExpiration
-from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
+from hivemind.utils.asyncio import achain, aenumerate, aiter, amap_in_executor, anext, azip
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
 

+ 1 - 2
tests/test_utils/p2p_daemon.py

@@ -6,14 +6,13 @@ import time
 import uuid
 from contextlib import asynccontextmanager
 from typing import NamedTuple
-from pkg_resources import resource_filename
 
 from multiaddr import Multiaddr, protocols
+from pkg_resources import resource_filename
 
 from hivemind import find_open_port
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 
-
 TIMEOUT_DURATION = 30  # seconds
 P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")