
Split hivemind.client into hivemind.averaging and hivemind.moe (#304)

* Split hivemind.client into hivemind.averaging and hivemind.moe

* Reduce the number of wildcard imports, update docs
Max Ryabinin, 4 years ago
parent
commit
5233b6c085
70 changed files with 254 additions and 238 deletions
  1. benchmarks/benchmark_averaging.py (+4 -5)
  2. benchmarks/benchmark_dht.py (+8 -6)
  3. benchmarks/benchmark_throughput.py (+9 -10)
  4. docs/modules/averaging.rst (+15 -0)
  5. docs/modules/client.rst (+3 -8)
  6. docs/modules/index.rst (+1 -0)
  7. docs/modules/server.rst (+9 -7)
  8. docs/user/benchmarks.md (+1 -1)
  9. docs/user/quickstart.md (+1 -1)
  10. hivemind/__init__.py (+7 -5)
  11. hivemind/averaging/__init__.py (+2 -0)
  12. hivemind/averaging/allreduce.py (+1 -1)
  13. hivemind/averaging/averager.py (+12 -12)
  14. hivemind/averaging/group_info.py (+0 -0)
  15. hivemind/averaging/key_manager.py (+1 -1)
  16. hivemind/averaging/load_balancing.py (+0 -0)
  17. hivemind/averaging/matchmaking.py (+6 -6)
  18. hivemind/averaging/partition.py (+0 -0)
  19. hivemind/averaging/training.py (+2 -2)
  20. hivemind/client/__init__.py (+0 -5)
  21. hivemind/dht/__init__.py (+3 -2)
  22. hivemind/dht/node.py (+8 -6)
  23. hivemind/dht/protocol.py (+4 -4)
  24. hivemind/dht/routing.py (+2 -1)
  25. hivemind/dht/schema.py (+1 -4)
  26. hivemind/dht/storage.py (+4 -4)
  27. hivemind/hivemind_cli/run_server.py (+2 -2)
  28. hivemind/moe/__init__.py (+2 -0)
  29. hivemind/moe/client/__init__.py (+3 -0)
  30. hivemind/moe/client/beam_search.py (+4 -4)
  31. hivemind/moe/client/expert.py (+0 -0)
  32. hivemind/moe/client/moe.py (+3 -3)
  33. hivemind/moe/client/switch_moe.py (+3 -3)
  34. hivemind/moe/server/__init__.py (+9 -10)
  35. hivemind/moe/server/checkpoints.py (+1 -1)
  36. hivemind/moe/server/connection_handler.py (+1 -1)
  37. hivemind/moe/server/dht_handler.py (+3 -3)
  38. hivemind/moe/server/expert_backend.py (+1 -1)
  39. hivemind/moe/server/expert_uid.py (+1 -1)
  40. hivemind/moe/server/layers/__init__.py (+9 -0)
  41. hivemind/moe/server/layers/common.py (+1 -1)
  42. hivemind/moe/server/layers/custom_experts.py (+1 -1)
  43. hivemind/moe/server/layers/dropout.py (+1 -1)
  44. hivemind/moe/server/layers/lr_schedule.py (+0 -0)
  45. hivemind/moe/server/runtime.py (+1 -1)
  46. hivemind/moe/server/task_pool.py (+0 -0)
  47. hivemind/optim/__init__.py (+1 -2)
  48. hivemind/optim/base.py (+0 -2)
  49. hivemind/optim/collaborative.py (+3 -4)
  50. hivemind/optim/simple.py (+1 -1)
  51. hivemind/p2p/p2p_daemon.py (+0 -2)
  52. hivemind/server/layers/__init__.py (+0 -9)
  53. hivemind/utils/auth.py (+2 -7)
  54. hivemind/utils/crypto.py (+0 -1)
  55. requirements-dev.txt (+1 -1)
  56. requirements-docs.txt (+1 -1)
  57. tests/test_allreduce.py (+7 -6)
  58. tests/test_averaging.py (+38 -33)
  59. tests/test_custom_experts.py (+2 -1)
  60. tests/test_dht_crypto.py (+6 -5)
  61. tests/test_dht_experts.py (+11 -10)
  62. tests/test_dht_node.py (+10 -10)
  63. tests/test_dht_schema.py (+1 -1)
  64. tests/test_dht_storage.py (+1 -1)
  65. tests/test_expert_backend.py (+2 -2)
  66. tests/test_moe.py (+7 -7)
  67. tests/test_p2p_daemon.py (+2 -2)
  68. tests/test_training.py (+4 -2)
  69. tests/test_util_modules.py (+3 -3)
  70. tests/test_utils/custom_networks.py (+1 -1)

+ 4 - 5
benchmarks/benchmark_averaging.py

@@ -1,14 +1,13 @@
+import argparse
 import math
-import time
 import threading
-import argparse
+import time
 
 import torch
 
 import hivemind
-from hivemind.utils import LOCALHOST, get_logger, increase_file_limit
 from hivemind.proto import runtime_pb2
-
+from hivemind.utils import LOCALHOST, get_logger, increase_file_limit
 
 logger = get_logger(__name__)
 
@@ -50,7 +49,7 @@ def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
                            initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
                            start=True)
         initial_bits = bin(index % num_groups)[2:].rjust(nbits, '0')
-        averager = hivemind.DecentralizedAverager(
+        averager = hivemind.averaging.DecentralizedAverager(
             peer_tensors[i], dht, prefix='my_tensor', initial_group_bits=initial_bits, listen_on=f"{LOCALHOST}:*",
             compression_type=runtime_pb2.CompressionType.FLOAT16, target_group_size=target_group_size,
             averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)

+ 8 - 6
benchmarks/benchmark_dht.py

@@ -5,7 +5,7 @@ import time
 from tqdm import trange
 
 import hivemind
-import hivemind.server.expert_uid
+from hivemind.moe.server import declare_experts, get_experts
 from hivemind.utils.limits import increase_file_limit
 
 logger = hivemind.get_logger(__name__)
@@ -43,8 +43,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
-        store_ok = hivemind.declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1],
-                                            expiration=expiration)
+        store_ok = declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1],
+                                   expiration=expiration)
         successes = store_ok.values()
         total_store_time += time.perf_counter() - store_start
 
@@ -52,7 +52,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
         successful_stores += sum(successes)
         time.sleep(wait_after_request)
 
-    logger.info(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
+    logger.info(
+        f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
     logger.info(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
     time.sleep(wait_before_read)
 
@@ -63,7 +64,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
     for start in trange(0, len(expert_uids), expert_batch_size):
         get_start = time.perf_counter()
-        get_result = hivemind.get_experts(get_peer, expert_uids[start: start + expert_batch_size])
+        get_result = get_experts(get_peer, expert_uids[start: start + expert_batch_size])
         total_get_time += time.perf_counter() - get_start
 
         for i, expert in enumerate(get_result):
@@ -74,7 +75,8 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     if time.perf_counter() - benchmark_started > expiration:
         logger.warning("keys expired midway during get requests. If that isn't desired, increase expiration_time param")
 
-    logger.info(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
+    logger.info(
+        f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
     logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
 
     alive_peers = [peer.is_alive() for peer in peers]

+ 9 - 10
benchmarks/benchmark_throughput.py

@@ -8,11 +8,10 @@ import torch
 
 import hivemind
 from hivemind import find_open_port
-from hivemind.server import layers
+from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 
-
 logger = get_logger(__name__)
 
 
@@ -88,8 +87,8 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
                                                            max_batch_size=max_batch_size,
                                                            )
         timestamps['created_experts'] = time.perf_counter()
-        server = hivemind.Server(None, experts, listen_on=f"{hivemind.LOCALHOST}:{port}",
-                                 num_connection_handlers=num_handlers, device=device)
+        server = hivemind.moe.Server(None, experts, listen_on=f"{hivemind.LOCALHOST}:{port}",
+                                     num_connection_handlers=num_handlers, device=device)
         server.start()
         server.ready.wait()
         timestamps['server_ready'] = time.perf_counter()
@@ -116,18 +115,18 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
     total_examples = batch_size * num_clients * num_batches_per_client
 
     logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
-    logger.info(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, max_batch_size={max_batch_size},"
-          f" expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}")
+    logger.info(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
+                f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}")
     logger.info(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
-          f"batch_size={batch_size}, backprop={backprop}")
+                f"batch_size={batch_size}, backprop={backprop}")
 
     logger.info("Results: ")
     logger.info(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
-          f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
-          f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
+                f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
+                f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
     logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
     logger.info(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
-          f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
+                f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
     logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
     if benchmarking_failed.is_set():
         logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")

+ 15 - 0
docs/modules/averaging.rst

@@ -0,0 +1,15 @@
+**hivemind.averaging**
+======================
+
+.. automodule:: hivemind.averaging
+
+.. currentmodule:: hivemind.averaging
+
+.. raw:: html
+
+  This module lets you average tensors in a decentralized manner.
+
+.. autoclass:: DecentralizedAverager
+   :members:
+   :member-order: bysource
+   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part, register_allreduce_group
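
To illustrate the relocated entry point documented here, a minimal two-peer sketch in the style of tests/test_averaging.py further down this diff; the tensor shapes and group prefix are arbitrary.

import torch
import hivemind

# Peers discover each other through a DHT instance (same pattern as the tests below).
dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')

# After this commit the averager is imported from hivemind.averaging.
averager1 = hivemind.averaging.DecentralizedAverager(
    [torch.randn(3)], dht=dht, prefix='mygroup', target_group_size=2, start=True)
averager2 = hivemind.averaging.DecentralizedAverager(
    [torch.randn(3)], dht=dht, prefix='mygroup', target_group_size=2, start=True)

# Each peer runs one averaging round; both futures resolve once the group has averaged.
for future in (averager1.step(wait=False), averager2.step(wait=False)):
    future.result()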

+ 3 - 8
docs/modules/client.rst

@@ -1,9 +1,9 @@
 **hivemind.client**
 ====================
 
-.. automodule:: hivemind.client
+.. automodule:: hivemind.moe.client
 
-.. currentmodule:: hivemind.client
+.. currentmodule:: hivemind.moe.client
 
 .. raw:: html
 
@@ -20,9 +20,4 @@
 
 .. autoclass:: RemoteSwitchMixtureOfExperts
    :members:
-   :member-order: bysource
-
-.. autoclass:: DecentralizedAverager
-   :members:
-   :member-order: bysource
-   :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part, register_allreduce_group
+   :member-order: bysource

+ 1 - 0
docs/modules/index.rst

@@ -5,6 +5,7 @@
 .. toctree::
    :maxdepth: 2
 
+   averaging
    client
    server
    dht

+ 9 - 7
docs/modules/server.rst

@@ -1,24 +1,24 @@
-**hivemind.server**
+**hivemind.moe.server**
 ========================================
 
 A hivemind server hosts one or several experts and processes incoming requests to those experts. It periodically
 re-publishes these experts to the dht via a dedicated **hivemind.dht.DHT** peer that runs in background.
-The experts can be accessed directly as **hivemind.client.RemoteExpert("addr:port", "expert.uid.here")**
-or as a part of **hivemind.client.RemoteMixtureOfExperts** that finds the most suitable experts across the DHT.
+The experts can be accessed directly as **hivemind.moe.client.RemoteExpert("addr:port", "expert.uid.here")**
+or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the most suitable experts across the DHT.
 
-The hivemind.server module is organized as follows:
+The hivemind.moe.server module is organized as follows:
 
 - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute.
 - Runtime_ balances the device (GPU) usage between several ExpertBackend_ instances that each service one expert.
 - ExpertBackend_ is a wrapper for `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_ \
-  that can be accessed by remote clients. It has two TaskPool_ -s for forward and backward requests.
+  that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests.
 - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \
   and offers those batches to Runtime_ for processing.
 
 
-.. automodule:: hivemind.server
+.. automodule:: hivemind.moe.server
 
-.. currentmodule:: hivemind.server
+.. currentmodule:: hivemind.moe.server
 
 .. _Server:
 .. autoclass:: Server
@@ -35,6 +35,8 @@ The hivemind.server module is organized as follows:
     :members: forward, backward, apply_gradients, get_info, get_pools
     :member-order: bysource
 
+.. currentmodule:: hivemind.moe.server.task_pool
+
 .. _TaskPool:
 .. autoclass:: TaskPool
     :members: submit_task, iterate_minibatches, load_batch_to_runtime, send_outputs_from_runtime, get_task_size, empty
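
Summarizing the moves above, the classes documented on this page now resolve from the following locations; TaskPool is imported from its own submodule, matching the new currentmodule directive, since it is no longer re-exported from hivemind.moe.server.

# New import paths after the split (previously hivemind.server and hivemind.client).
from hivemind.moe.server import Server, ExpertBackend
from hivemind.moe.server.runtime import Runtime
from hivemind.moe.server.task_pool import TaskPool
from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts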

+ 1 - 1
docs/user/benchmarks.md

@@ -6,7 +6,7 @@ hivemind.
 ### Server throughput
 
 You can use [this benchmark](https://github.com/learning-at-home/hivemind/blob/master/benchmarks/benchmark_throughput.py) to
-check the performance impact of your changes to hivemind.client and server. The benchmark will start one server without
+check the performance impact of your changes to hivemind.moe. The benchmark will start one server without
 DHT with several experts, and then spawn trainer processes that load the server with requests. The two main statistics
 in this benchmark samples/s and startup time.
 

+ 1 - 1
docs/user/quickstart.md

@@ -24,7 +24,7 @@ You can also install it in the editable mode with `pip install -e .`.
 
 ## Host a server
 
-`hivemind.Server` hosts one or several experts (PyTorch modules) for remote access. These experts are responsible for
+`hivemind.moe.Server` hosts one or several experts (PyTorch modules) for remote access. These experts are responsible for
 most of the model parameters and computation. The server can be started using either Python or
 [a shell script](https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_server.py). We'll use the shell
 for now. To host a server with default experts, run this in your shell:

+ 7 - 5
hivemind/__init__.py

@@ -1,8 +1,10 @@
-from hivemind.client import *
-from hivemind.dht import *
-from hivemind.p2p import *
-from hivemind.server import *
+from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.dht import DHT
+from hivemind.moe import ExpertBackend, Server, register_expert_class, RemoteExpert, RemoteMixtureOfExperts, \
+    RemoteSwitchMixtureOfExperts
+from hivemind.optim import CollaborativeAdaptiveOptimizer, DecentralizedOptimizerBase, CollaborativeOptimizer, \
+    DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
-from hivemind.optim import *
 
 __version__ = "0.9.10"
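
Because the wildcard imports were replaced with explicit re-exports, the familiar top-level names keep pointing at the relocated classes; a quick sanity-check sketch:

import hivemind

# These hold because hivemind/__init__.py re-exports the classes from their new homes.
assert hivemind.Server is hivemind.moe.Server
assert hivemind.DecentralizedAverager is hivemind.averaging.DecentralizedAverager
assert hivemind.RemoteExpert is hivemind.moe.RemoteExpert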

+ 2 - 0
hivemind/averaging/__init__.py

@@ -0,0 +1,2 @@
+from hivemind.averaging.averager import DecentralizedAverager
+from hivemind.averaging.training import TrainingAverager

+ 1 - 1
hivemind/client/averaging/allreduce.py → hivemind/averaging/allreduce.py

@@ -5,7 +5,7 @@ from enum import Enum
 import grpc
 import torch
 
-from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
+from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
 from hivemind.utils import Endpoint, get_logger, ChannelCache
 from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor

+ 12 - 12
hivemind/client/averaging/__init__.py → hivemind/averaging/averager.py

@@ -15,23 +15,23 @@ from dataclasses import asdict
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
 import grpc
-from grpc._cython.cygrpc import InternalError
-import torch
 import numpy as np
+import torch
+from grpc._cython.cygrpc import InternalError
 
+from hivemind.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
 from hivemind.dht import DHT, DHTID
-from hivemind.client.averaging.partition import DEFAULT_PART_SIZE_BYTES
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, AveragingMode
-from hivemind.client.averaging.load_balancing import load_balance_peers
-from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
-from hivemind.client.averaging.group_info import GroupInfo
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
-from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
+from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor
+from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
 
 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
@@ -582,7 +582,7 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
         except BaseException as e:
             logger.debug(f"Averager background thread finished: {repr(e)}")
             break
-            
+
         if trigger == '_SHUTDOWN':
             break
 

+ 0 - 0
hivemind/client/averaging/group_info.py → hivemind/averaging/group_info.py


+ 1 - 1
hivemind/client/averaging/key_manager.py → hivemind/averaging/key_manager.py

@@ -6,7 +6,7 @@ from typing import Optional, List, Tuple
 import numpy as np
 
 from hivemind.dht import DHT
-from hivemind.client.averaging.group_info import GroupInfo
+from hivemind.averaging.group_info import GroupInfo
 from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
 
 GroupKey = str

+ 0 - 0
hivemind/client/averaging/load_balancing.py → hivemind/averaging/load_balancing.py


+ 6 - 6
hivemind/client/averaging/matchmaking.py → hivemind/averaging/matchmaking.py

@@ -12,14 +12,13 @@ import asyncio
 import grpc
 import grpc._cython.cygrpc
 
-from hivemind.client.averaging.group_info import GroupInfo
-from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
-from hivemind.dht import DHT, DHTID, DHTExpiration, get_dht_time
-from hivemind.utils import get_logger, Endpoint, timed_storage, TimedStorage
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
+from hivemind.dht import DHT, DHTID, DHTExpiration
+from hivemind.utils import get_logger, Endpoint, timed_storage, TimedStorage, get_dht_time
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc
 from hivemind.utils.grpc import ChannelCache
 
-
 logger = get_logger(__name__)
 
 
@@ -391,7 +390,8 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()
 
-            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (self.declared_expiration_time, self.endpoint):
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (
+                    self.declared_expiration_time, self.endpoint):
                 await asyncio.wait({self.update_finished.wait(), self.declared_expiration.wait()},
                                    return_when=asyncio.FIRST_COMPLETED)
                 self.declared_expiration.clear()

+ 0 - 0
hivemind/client/averaging/partition.py → hivemind/averaging/partition.py


+ 2 - 2
hivemind/client/averaging/training.py → hivemind/averaging/training.py

@@ -1,13 +1,13 @@
 """ An extension of averager that supports common optimization use cases. """
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from itertools import chain
 from threading import Lock
 from typing import Sequence, Dict, Iterator, Optional
-from contextlib import nullcontext
 
 import torch
 
-from hivemind.client.averaging import DecentralizedAverager
+from hivemind.averaging import DecentralizedAverager
 from hivemind.utils import nested_flatten, nested_pack, get_logger
 
 logger = get_logger(__name__)

+ 0 - 5
hivemind/client/__init__.py

@@ -1,5 +0,0 @@
-from hivemind.client.expert import RemoteExpert
-from hivemind.client.moe import RemoteMixtureOfExperts
-from hivemind.client.switch_moe import RemoteSwitchMixtureOfExperts
-from hivemind.client.averaging import DecentralizedAverager
-from hivemind.client.averaging.training import TrainingAverager

+ 3 - 2
hivemind/dht/__init__.py

@@ -13,6 +13,7 @@ The code is organized as follows:
 - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
 """
 from __future__ import annotations
+
 import asyncio
 import ctypes
 import multiprocessing as mp
@@ -21,11 +22,11 @@ from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Iterable, Optional, Sequence, Union, Callable, Awaitable, TypeVar
 
-from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
+from hivemind.dht.node import DHTNode, DHTID
 from hivemind.dht.routing import DHTValue, DHTKey, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.utils import MPFuture, get_logger, switch_to_uvloop, ValueWithExpiration, await_cancelled, DHTExpiration
 from hivemind.utils.networking import Hostname, Endpoint, strip_port
-from hivemind.utils import MPFuture, get_logger, switch_to_uvloop, ValueWithExpiration, await_cancelled, get_dht_time
 
 logger = get_logger(__name__)
 

+ 8 - 6
hivemind/dht/node.py

@@ -12,10 +12,10 @@ from sortedcontainers import SortedSet
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase, DHTExpiration
 from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration
 
 logger = get_logger(__name__)
@@ -27,7 +27,7 @@ class DHTNode:
     Each DHTNode has an identifier, a local storage and access too other nodes via DHTProtocol.
 
     :note: Hivemind DHT is optimized to store a lot of temporary metadata that is regularly updated.
-     For example, expert heartbeat emitted by a hivemind.Server responsible for that expert.
+     For example, expert heartbeat emitted by a hivemind.moe.Server responsible for that expert.
      Such metadata does not require regular maintenance by peers or persistence on shutdown.
      Instead, DHTNode is designed to rapidly send bulk data and resolve conflicts.
 
@@ -139,8 +139,8 @@ class DHTNode:
         self.cache_refresh_task = None
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-                                                 parallel_rpc, cache_size, listen, listen_on, endpoint, record_validator,
-                                                 **kwargs)
+                                                 parallel_rpc, cache_size, listen, listen_on, endpoint,
+                                                 record_validator, **kwargs)
         self.port = self.protocol.port
 
         if initial_peers:
@@ -361,7 +361,8 @@ class DHTNode:
         try:
             await asyncio.gather(store_task, *(evt.wait() for evt in store_finished_events.values()))
             assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
-            return {(key, subkey) if subkey is not None else key: status or False for (key, subkey), status in store_ok.items()}
+            return {(key, subkey) if subkey is not None else key: status or False
+                    for (key, subkey), status in store_ok.items()}
         except asyncio.CancelledError as e:
             store_task.cancel()
             raise e
@@ -711,6 +712,7 @@ class Blacklist:
     :param base_time: peers are suspended for this many seconds by default
     :param backoff_rate: suspension time increases by this factor after each successive failure
     """
+
     def __init__(self, base_time: float, backoff_rate: float, **kwargs):
         self.base_time, self.backoff = base_time, backoff_rate
         self.banned_peers = TimedStorage[Endpoint, int](**kwargs)
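
A small sketch of the "temporary metadata" pattern from the docstring above, assuming the DHT.store / DHT.get interface with an explicit expiration time; the key and value here are made up for illustration.

import hivemind
from hivemind.utils import get_dht_time

dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')

# Store a short-lived record; after ~30 seconds the DHT simply forgets it,
# which is the behaviour described above for expert heartbeats.
dht.store('expert.heartbeat.demo', b'alive', expiration_time=get_dht_time() + 30)

found = dht.get('expert.heartbeat.demo')  # a value with expiration now, None after expiry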

+ 4 - 4
hivemind/dht/protocol.py

@@ -2,18 +2,18 @@
 from __future__ import annotations
 
 import asyncio
-from itertools import zip_longest
 from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
 
 import grpc
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
-from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
+from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
-from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
-from hivemind.utils import get_dht_time, GRPC_KEEPALIVE_OPTIONS, MAX_DHT_TIME_DISCREPANCY_SECONDS
+from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, GRPC_KEEPALIVE_OPTIONS
 from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
+from hivemind.utils.timed_storage import DHTExpiration, get_dht_time, MAX_DHT_TIME_DISCREPANCY_SECONDS, \
+    ValueWithExpiration
 
 logger = get_logger(__name__)
 

+ 2 - 1
hivemind/dht/routing.py

@@ -8,7 +8,8 @@ import random
 from collections.abc import Iterable
 from itertools import chain
 from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
-from hivemind.utils import Endpoint, MSGPackSerializer, get_dht_time, DHTExpiration
+
+from hivemind.utils import Endpoint, MSGPackSerializer, get_dht_time
 
 DHTKey, Subkey, DHTValue, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, bytes, bytes
 

+ 1 - 4
hivemind/dht/schema.py

@@ -1,17 +1,14 @@
-import binascii
 import re
-from contextlib import contextmanager
 from typing import Any, Dict, Optional, Type
 
 import pydantic
 
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.routing import DHTID, DHTKey
+from hivemind.dht.routing import DHTID
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase
 from hivemind.utils import get_logger
 
-
 logger = get_logger(__name__)
 
 

+ 4 - 4
hivemind/dht/storage.py

@@ -1,9 +1,10 @@
 from __future__ import annotations
+
 from typing import Optional, Union
 
-from hivemind.dht.routing import DHTID, DHTExpiration, BinaryDHTValue, Subkey
+from hivemind.dht.routing import DHTID, BinaryDHTValue, Subkey
 from hivemind.utils.serializer import MSGPackSerializer
-from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage
+from hivemind.utils.timed_storage import KeyType, ValueType, TimedStorage, DHTExpiration
 
 
 @MSGPackSerializer.ext_serializable(0x50)
@@ -32,6 +33,7 @@ class DictionaryDHTValue(TimedStorage[Subkey, BinaryDHTValue]):
 
 class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTValue]]):
     """ A dictionary-like storage that can store binary values and/or nested dictionaries until expiration """
+
     def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration,
               subkey: Optional[Subkey] = None) -> bool:
         """
@@ -63,5 +65,3 @@ class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTVal
             return previous_value.store(subkey, value, expiration_time)
         else:
             return False
-
-

+ 2 - 2
hivemind/hivemind_cli/run_server.py

@@ -5,10 +5,10 @@ import configargparse
 import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.server import Server
+from hivemind.moe.server import Server
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
-from hivemind.server.layers import schedule_name_to_scheduler
+from hivemind.moe.server.layers import schedule_name_to_scheduler
 
 logger = get_logger(__name__)
 

+ 2 - 0
hivemind/moe/__init__.py

@@ -0,0 +1,2 @@
+from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.server import ExpertBackend, Server, register_expert_class

+ 3 - 0
hivemind/moe/client/__init__.py

@@ -0,0 +1,3 @@
+from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.client.moe import RemoteMixtureOfExperts
+from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts

+ 4 - 4
hivemind/client/beam_search.py → hivemind/moe/client/beam_search.py

@@ -5,9 +5,9 @@ from functools import partial
 from typing import Sequence, Optional, List, Tuple, Dict, Deque, Union, Set, Iterator
 
 from hivemind.dht import DHT, DHTNode, DHTExpiration
-from hivemind.client.expert import RemoteExpert
-from hivemind.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, UidEndpoint, Score, Coordinate,
-                                        PREFIX_PATTERN, UID_DELIMITER, is_valid_prefix)
+from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, UidEndpoint, Score, Coordinate,
+                                            PREFIX_PATTERN, UID_DELIMITER, is_valid_prefix)
 from hivemind.utils import get_logger, get_dht_time, MPFuture
 
 logger = get_logger(__name__)
@@ -22,7 +22,7 @@ class MoEBeamSearcher:
         * optional prefix that determines expert role, experiment name, etc.
         * one or more integers that determine that expert's position in an N-dimensional grid
 
-    A hivemind.Server can ``declare_experts(dht, expert_uids: List[str])`` to make its experts visible to everyone.
+    A hivemind.moe.Server can ``declare_experts(dht, expert_uids: List[str])`` to make its experts visible to everyone.
     When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
     For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
     ``"ffn_expert.98", "ffn_expert.98.76", "ffn_expert.98.76.54", ..., "ffn_expert.98.76.54.32.10"``

+ 0 - 0
hivemind/client/expert.py → hivemind/moe/client/expert.py


+ 3 - 3
hivemind/client/moe.py → hivemind/moe/client/moe.py

@@ -10,10 +10,10 @@ import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 import hivemind
-from hivemind.client.beam_search import MoEBeamSearcher
-from hivemind.client.expert import RemoteExpert, DUMMY, _get_expert_stub
+from hivemind.moe.client.beam_search import MoEBeamSearcher
+from hivemind.moe.client.expert import RemoteExpert, DUMMY, _get_expert_stub
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.server.expert_uid import UID_DELIMITER
+from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.utils import nested_pack, nested_flatten, nested_map
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.logging import get_logger

+ 3 - 3
hivemind/client/switch_moe.py → hivemind/moe/client/switch_moe.py

@@ -5,9 +5,9 @@ from typing import Tuple, List
 import grpc
 import torch
 
-from hivemind.client.expert import RemoteExpert, DUMMY
-from hivemind.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
-from hivemind.server.expert_uid import UID_DELIMITER
+from hivemind.moe.client.expert import RemoteExpert, DUMMY
+from hivemind.moe.client.moe import RemoteMixtureOfExperts, _RemoteCallMany
+from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.utils import nested_pack, nested_flatten
 from hivemind.utils.logging import get_logger
 

+ 9 - 10
hivemind/server/__init__.py → hivemind/moe/server/__init__.py

@@ -12,15 +12,14 @@ import torch
 
 import hivemind
 from hivemind.dht import DHT
-from hivemind.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
-from hivemind.server.checkpoints import CheckpointSaver, load_experts, is_directory
-from hivemind.server.connection_handler import ConnectionHandler
-from hivemind.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
-from hivemind.server.expert_backend import ExpertBackend
-from hivemind.server.layers import name_to_block, name_to_input
-from hivemind.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
-from hivemind.server.runtime import Runtime
-from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
+from hivemind.moe.server.expert_uid import UID_DELIMITER, generate_uids_from_pattern
+from hivemind.moe.server.checkpoints import CheckpointSaver, load_experts, is_directory
+from hivemind.moe.server.connection_handler import ConnectionHandler
+from hivemind.moe.server.dht_handler import DHTHandlerThread, declare_experts, get_experts
+from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.moe.server.layers import name_to_block, name_to_input, register_expert_class
+from hivemind.moe.server.layers import add_custom_models_from_file, schedule_name_to_scheduler
+from hivemind.moe.server.runtime import Runtime
 from hivemind.utils import Endpoint, get_port, replace_port, find_open_port, get_logger, BatchTensorDescriptor
 from hivemind.proto.runtime_pb2 import CompressionType
 
@@ -86,7 +85,7 @@ class Server(threading.Thread):
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
            means "sample random experts between myprefix.0.0 and myprefix.255.255;
         :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
-        :param expert_cls: expert type from hivemind.server.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
+        :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
         :param hidden_dim: main dimension for expert_cls
         :param num_handlers: server will use this many parallel processes to handle incoming requests
         :param min_batch_size: total num examples in the same batch will be greater than this value
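
Assuming these :param: entries belong to Server.create, a hedged sketch of launching a small server with pattern-generated experts (all values illustrative):

from hivemind.moe.server import Server

server = Server.create(num_experts=16, expert_pattern='myprefix.[0:32].[0:256]',
                       expert_cls='ffn', hidden_dim=512, num_handlers=4,
                       listen_on='0.0.0.0:*', start=True)
server.ready.wait()  # same readiness handshake as in benchmarks/benchmark_throughput.py
server.shutdown()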

+ 1 - 1
hivemind/server/checkpoints.py → hivemind/moe/server/checkpoints.py

@@ -8,7 +8,7 @@ from typing import Dict
 
 import torch
 
-from hivemind.server.expert_backend import ExpertBackend
+from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)

+ 1 - 1
hivemind/server/connection_handler.py → hivemind/moe/server/connection_handler.py

@@ -7,7 +7,7 @@ import grpc
 import torch
 
 from hivemind.proto import runtime_pb2, runtime_pb2_grpc as runtime_grpc
-from hivemind.server.expert_backend import ExpertBackend
+from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.utils import get_logger, Endpoint, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor

+ 3 - 3
hivemind/server/dht_handler.py → hivemind/moe/server/dht_handler.py

@@ -3,9 +3,9 @@ from functools import partial
 from typing import Sequence, Dict, List, Tuple, Optional
 
 from hivemind.dht import DHT, DHTNode, DHTExpiration, DHTValue
-from hivemind.client.expert import RemoteExpert
-from hivemind.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, Coordinate,
-                                        UID_DELIMITER, UID_PATTERN, is_valid_uid, split_uid)
+from hivemind.moe.client.expert import RemoteExpert
+from hivemind.moe.server.expert_uid import (ExpertUID, ExpertPrefix, FLAT_EXPERT, Coordinate,
+                                            UID_DELIMITER, UID_PATTERN, is_valid_uid, split_uid)
 from hivemind.utils import Endpoint, get_dht_time, get_port
 
 

+ 1 - 1
hivemind/server/expert_backend.py → hivemind/moe/server/expert_backend.py

@@ -3,7 +3,7 @@ from typing import Dict, Sequence, Any, Tuple, Union, Callable
 import torch
 from torch import nn
 
-from hivemind.server.task_pool import TaskPool
+from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils import BatchTensorDescriptor, DUMMY_BATCH_SIZE
 from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_flatten, nested_pack, nested_compare, nested_map

+ 1 - 1
hivemind/server/expert_uid.py → hivemind/moe/server/expert_uid.py

@@ -82,7 +82,7 @@ def generate_uids_from_pattern(num_experts: int, expert_pattern: Optional[str],
 
         # 2. look into DHT (if given) and remove duplicates
         if dht:
-            existing_expert_uids = {found_expert.uid for found_expert in hivemind.get_experts(dht, new_uids)
+            existing_expert_uids = {found_expert.uid for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
                                     if found_expert is not None}
             new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
 

+ 9 - 0
hivemind/moe/server/layers/__init__.py

@@ -0,0 +1,9 @@
+name_to_block = {}
+name_to_input = {}
+
+import hivemind.moe.server.layers.common
+import hivemind.moe.server.layers.dropout
+from hivemind.moe.server.layers.custom_experts import add_custom_models_from_file, register_expert_class
+from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
+
+schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}

+ 1 - 1
hivemind/server/layers/common.py → hivemind/moe/server/layers/common.py

@@ -3,7 +3,7 @@ import time
 import torch
 from torch import nn as nn
 
-from hivemind.server.layers.custom_experts import register_expert_class
+from hivemind.moe.server.layers.custom_experts import register_expert_class
 
 
 # https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py

+ 1 - 1
hivemind/server/layers/custom_experts.py → hivemind/moe/server/layers/custom_experts.py

@@ -5,7 +5,7 @@ from typing import Callable, Type
 import torch
 import torch.nn as nn
 
-from hivemind.server.layers import name_to_block, name_to_input
+from hivemind.moe.server.layers import name_to_block, name_to_input
 
 
 def add_custom_models_from_file(path: str):

+ 1 - 1
hivemind/server/layers/dropout.py → hivemind/moe/server/layers/dropout.py

@@ -1,7 +1,7 @@
 import torch.autograd
 from torch import nn as nn
 
-from hivemind.server.layers.custom_experts import register_expert_class
+from hivemind.moe.server.layers.custom_experts import register_expert_class
 
 
 class DeterministicDropoutFunction(torch.autograd.Function):

+ 0 - 0
hivemind/server/layers/lr_schedule.py → hivemind/moe/server/layers/lr_schedule.py


+ 1 - 1
hivemind/server/runtime.py → hivemind/moe/server/runtime.py

@@ -12,7 +12,7 @@ from typing import Dict, NamedTuple, Optional
 import torch
 from prefetch_generator import BackgroundGenerator
 
-from hivemind.server.expert_backend import ExpertBackend
+from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.utils import get_logger
 
 logger = get_logger(__name__)

+ 0 - 0
hivemind/server/task_pool.py → hivemind/moe/server/task_pool.py


+ 1 - 2
hivemind/optim/__init__.py

@@ -1,5 +1,4 @@
+from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.optim.simple import DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
-from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer

+ 0 - 2
hivemind/optim/base.py

@@ -1,5 +1,3 @@
-from typing import Any
-
 import torch
 
 from hivemind.dht import DHT

+ 3 - 4
hivemind/optim/collaborative.py

@@ -9,14 +9,13 @@ import numpy as np
 import torch
 from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
 
-from hivemind.client.averaging.training import TrainingAverager
+from hivemind.averaging.training import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.utils import Endpoint, ValueWithExpiration, get_dht_time, get_logger
-
+from hivemind.utils import Endpoint, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
@@ -115,7 +114,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         self.prefix, self.scheduler = prefix, scheduler
         self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
-        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period =\
+        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = \
             min_refresh_period, max_refresh_period, default_refresh_period
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
         self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration

+ 1 - 1
hivemind/optim/simple.py

@@ -5,7 +5,7 @@ from typing import Optional, Sequence, Tuple
 import torch
 
 from hivemind.dht import DHT
-from hivemind.client import TrainingAverager
+from hivemind.averaging import TrainingAverager
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils import get_logger, get_dht_time
 

+ 0 - 2
hivemind/p2p/p2p_daemon.py

@@ -1,6 +1,5 @@
 import asyncio
 import secrets
-from copy import deepcopy
 from dataclasses import dataclass
 from importlib.resources import path
 from subprocess import Popen
@@ -15,7 +14,6 @@ from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, St
 from hivemind.proto import p2pd_pb2
 from hivemind.utils import MSGPackSerializer
 from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import find_open_port
 
 logger = get_logger(__name__)
 

+ 0 - 9
hivemind/server/layers/__init__.py

@@ -1,9 +0,0 @@
-name_to_block = {}
-name_to_input = {}
-
-import hivemind.server.layers.common
-import hivemind.server.layers.dropout
-from hivemind.server.layers.custom_experts import add_custom_models_from_file, register_expert_class
-from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
-
-schedule_name_to_scheduler = {'linear': get_linear_schedule_with_warmup, 'none': None}

+ 2 - 7
hivemind/utils/auth.py

@@ -1,21 +1,16 @@
 import asyncio
 import functools
 import secrets
-import threading
-import time
 from abc import ABC, abstractmethod
-from enum import Enum
 from datetime import timedelta
-from typing import Optional, Tuple
-
-import grpc
+from enum import Enum
+from typing import Optional
 
 from hivemind.proto.auth_pb2 import AccessToken, RequestAuthInfo, ResponseAuthInfo
 from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
 from hivemind.utils.logging import get_logger
 from hivemind.utils.timed_storage import TimedStorage, get_dht_time
 
-
 logger = get_logger(__name__)
 
 

+ 0 - 1
hivemind/utils/crypto.py

@@ -3,7 +3,6 @@ from __future__ import annotations
 import base64
 import threading
 from abc import ABC, abstractmethod
-from typing import Optional
 
 from cryptography import exceptions
 from cryptography.hazmat.primitives import hashes, serialization

+ 1 - 1
requirements-dev.txt

@@ -5,4 +5,4 @@ pytest-cov
 codecov
 tqdm
 scikit-learn
-psutil
+psutil

+ 1 - 1
requirements-docs.txt

@@ -1,2 +1,2 @@
 recommonmark
-sphinx_rtd_theme
+sphinx_rtd_theme

+ 7 - 6
tests/test_allreduce.py

@@ -3,16 +3,16 @@ import random
 import time
 from typing import Sequence
 
+import grpc
 import pytest
 import torch
-import grpc
 
 from hivemind import aenumerate, Endpoint
-from hivemind.client.averaging.allreduce import AllReduceRunner, AveragingMode
-from hivemind.client.averaging.partition import TensorPartContainer, TensorPartReducer
-from hivemind.utils import deserialize_torch_tensor, ChannelCache
-from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.proto import averaging_pb2_grpc
+from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils import deserialize_torch_tensor, ChannelCache
 
 
 @pytest.mark.forked
@@ -140,6 +140,7 @@ async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float
 
 class AllreduceRunnerForTesting(AllReduceRunner):
     """ a version of AllReduceRunner that was monkey-patched to accept custom endpoint names """
+
     def __init__(self, *args, peer_endpoints, **kwargs):
         self.__peer_endpoints = peer_endpoints
         super().__init__(*args, **kwargs)
@@ -162,7 +163,7 @@ NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
     ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0)),
     ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4)),
 ])
-@pytest.mark.parametrize("part_size_bytes", [2 ** 20, 256, 19],)
+@pytest.mark.parametrize("part_size_bytes", [2 ** 20, 256, 19], )
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):

+ 38 - 33
tests/test_averaging.py

@@ -1,13 +1,14 @@
 import random
 
 import numpy as np
-import torch
 import pytest
+import torch
 
 import hivemind
-from hivemind.client.averaging.allreduce import AveragingMode
-from hivemind.client.averaging.load_balancing import load_balance_peers
-from hivemind.client.averaging.key_manager import GroupKeyManager
+import hivemind.averaging.averager
+from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.key_manager import GroupKeyManager
+from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.proto.runtime_pb2 import CompressionType
 
 
@@ -45,7 +46,8 @@ def _test_allreduce_once(n_clients, n_aux):
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
 
     n_peers = 4
-    modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
+    modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (
+            n_peers - n_clients - n_aux)
     random.shuffle(modes)
 
     tensors1 = [torch.randn(123), torch.zeros(3)]
@@ -55,12 +57,14 @@ def _test_allreduce_once(n_clients, n_aux):
     peer_tensors = [tensors1, tensors2, tensors3, tensors4]
 
     reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
-                 if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]
+                     if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]
 
-    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', listen=mode != AveragingMode.CLIENT, listen_on='127.0.0.1:*',
-                                                auxiliary=mode == AveragingMode.AUX, start=True)
-                 for tensors, mode in zip(peer_tensors, modes)]
+    averagers = [
+        hivemind.averaging.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
+                                                 prefix='mygroup', listen=mode != AveragingMode.CLIENT,
+                                                 listen_on='127.0.0.1:*',
+                                                 auxiliary=mode == AveragingMode.AUX, start=True)
+        for tensors, mode in zip(peer_tensors, modes)]
 
     futures = []
     for averager in averagers:
@@ -106,10 +110,10 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
-    averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', listen=listen, listen_on='127.0.0.1:*',
-                                                start=True)
-                 for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
+    averagers = [
+        hivemind.averaging.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
+                                                 prefix='mygroup', listen=listen, listen_on='127.0.0.1:*', start=True)
+        for tensors, listen in zip([tensors1, tensors2, tensors3, tensors4], should_listen)]
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
     reference = [(tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2]
                   + tensors4[i] * weights[3]) / sum(weights) for i in range(len(tensors1))]
@@ -142,12 +146,13 @@ def test_allreduce_compression():
     FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
 
     for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
-        averager1 = hivemind.DecentralizedAverager([x.clone() for x in tensors1], dht=dht,
-                                                   compression_type=compression_type_pair, listen=False,
-                                                   target_group_size=2, prefix='mygroup', start=True)
-        averager2 = hivemind.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
-                                                   compression_type=compression_type_pair,
-                                                   target_group_size=2, prefix='mygroup', start=True)
+        averager1 = hivemind.averaging.DecentralizedAverager([x.clone() for x in tensors1], dht=dht,
+                                                             compression_type=compression_type_pair,
+                                                             listen=False, target_group_size=2, prefix='mygroup',
+                                                             start=True)
+        averager2 = hivemind.averaging.DecentralizedAverager([x.clone() for x in tensors2], dht=dht,
+                                                             compression_type=compression_type_pair,
+                                                             target_group_size=2, prefix='mygroup', start=True)
 
         for future in averager1.step(wait=False), averager2.step(wait=False):
             future.result()
@@ -186,7 +191,7 @@ def compute_mean_std(averagers, unbiased=True):
 @pytest.mark.forked
 def test_allreduce_grid():
     dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
-    averagers = [hivemind.DecentralizedAverager(
+    averagers = [hivemind.averaging.DecentralizedAverager(
         averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
         prefix='mygroup', initial_group_bits=bin(i // 2)[2:].rjust(2, '0'), start=True)
         for i in range(8)]
@@ -216,9 +221,9 @@ def test_allreduce_grid():
 @pytest.mark.forked
 def test_allgather():
     dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
-    averagers = [hivemind.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4, averaging_expiration=15,
-                                                prefix='mygroup', initial_group_bits='000', listen_on='127.0.0.1:*',
-                                                start=True)
+    averagers = [hivemind.averaging.DecentralizedAverager([torch.ones(1)], dht=dht, target_group_size=4,
+                                                          averaging_expiration=15, prefix='mygroup',
+                                                          initial_group_bits='000', listen_on='127.0.0.1:*', start=True)
                  for _ in range(8)]
 
     futures = []
@@ -286,7 +291,7 @@ def test_load_balancing():
 @pytest.mark.forked
 def test_too_few_peers():
     dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
-    averagers = [hivemind.DecentralizedAverager(
+    averagers = [hivemind.averaging.DecentralizedAverager(
         averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
         averaging_expiration=1, request_timeout=0.5,
         prefix='mygroup', initial_group_bits=bin(i)[2:].rjust(3, '0'), start=True)
@@ -303,7 +308,7 @@ def test_too_few_peers():
 @pytest.mark.forked
 def test_overcrowded(num_peers=16):
     dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
-    averagers = [hivemind.DecentralizedAverager(
+    averagers = [hivemind.averaging.DecentralizedAverager(
         averaged_tensors=[torch.randn(3)], dht=dht, target_group_size=2,
         averaging_expiration=1, request_timeout=0.5,
         prefix='mygroup', initial_group_bits='', start=True)
@@ -323,7 +328,7 @@ def test_load_state_from_peers():
     super_metadata = dict(x=123)
     super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))
 
-    class TestAverager(hivemind.DecentralizedAverager):
+    class TestAverager(hivemind.averaging.DecentralizedAverager):
         def get_current_state(self):
             """
             Get current state and send it to a peer. executed in the host process. Meant to be overriden.
@@ -373,8 +378,8 @@ def test_load_state_from_peers():
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True, endpoint='127.0.0.1:*')
-    averager = hivemind.DecentralizedAverager([torch.randn(3)], dht=dht, start=True,
-                                              prefix='test_prefix', target_group_size=2)
+    averager = hivemind.averaging.DecentralizedAverager([torch.randn(3)], dht=dht, start=True,
+                                                        prefix='test_prefix', target_group_size=2)
     averager.set_group_bits('00101011101010')
     assert averager.get_group_bits() == '00101011101010'
 
@@ -389,13 +394,13 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.client.TrainingAverager(opt1, average_gradients=True, average_parameters=True,
-                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)
+    averager1 = hivemind.averaging.TrainingAverager(opt1, average_gradients=True, average_parameters=True,
+                                                    average_opt_statistics=["exp_avg_sq"], **common_kwargs)
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.client.TrainingAverager(opt2, average_gradients=True, average_parameters=True,
-                                                 average_opt_statistics=["exp_avg_sq"], **common_kwargs)
+    averager2 = hivemind.averaging.TrainingAverager(opt2, average_gradients=True, average_parameters=True,
+                                                    average_opt_statistics=["exp_avg_sq"], **common_kwargs)
     a = torch.ones(n_dims)
 
     for i in range(n_steps):

+ 2 - 1
tests/test_custom_experts.py

@@ -3,7 +3,8 @@ import os
 import pytest
 import torch
 
-from hivemind import RemoteExpert, background_server
+from hivemind import RemoteExpert
+from hivemind.moe.server import background_server
 
 CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), 'test_utils', 'custom_networks.py')
 

+ 6 - 5
tests/test_dht_crypto.py

@@ -5,9 +5,9 @@ import multiprocessing as mp
 import pytest
 
 import hivemind
-from hivemind.dht import get_dht_time
+from hivemind.utils.timed_storage import get_dht_time
 from hivemind.dht.crypto import RSASignatureValidator
-from hivemind.dht.node import LOCALHOST
+from hivemind.dht.node import LOCALHOST, DHTNode
 from hivemind.dht.validation import DHTRecord
 from hivemind.utils.crypto import RSAPrivateKey
 
@@ -82,6 +82,7 @@ def get_signed_record(conn: mp.connection.Connection) -> DHTRecord:
     record = dataclasses.replace(record, value=validator.sign_value(record))
 
     conn.send(record)
+    return record
 
 
 def test_signing_in_different_process():
@@ -104,11 +105,11 @@ def test_signing_in_different_process():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_signatures():
-    alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
-    bob = await hivemind.DHTNode.create(
+    alice = await DHTNode.create(record_validator=RSASignatureValidator())
+    bob = await DHTNode.create(
         record_validator=RSASignatureValidator(RSAPrivateKey()),
         initial_peers=[f"{LOCALHOST}:{alice.port}"])
-    mallory = await hivemind.DHTNode.create(
+    mallory = await DHTNode.create(
         record_validator=RSASignatureValidator(RSAPrivateKey()),
         initial_peers=[f"{LOCALHOST}:{alice.port}"])
 

+ 11 - 10
tests/test_dht_experts.py

@@ -6,10 +6,11 @@ import numpy as np
 import pytest
 
 import hivemind
-import hivemind.server.expert_uid
-from hivemind import LOCALHOST, declare_experts, get_experts
-from hivemind.client.beam_search import MoEBeamSearcher
-from hivemind.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
+from hivemind.dht import DHTNode
+from hivemind import LOCALHOST
+from hivemind.moe.client.beam_search import MoEBeamSearcher
+from hivemind.moe.server import declare_experts, get_experts
+from hivemind.moe.server.expert_uid import UidEndpoint, is_valid_uid, is_valid_prefix, split_uid
 
 
 @pytest.mark.forked
@@ -25,14 +26,14 @@ def test_store_get_experts():
     expert_uids = [f"my_expert.{i}" for i in range(50)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
-        hivemind.declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
+        declare_experts(first_peer, expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
 
     found = get_experts(other_peer, random.sample(expert_uids, 5) + ['foo', 'bar'])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
     other_expert, other_port = "my_other_expert.1337", random.randint(1000, 9999)
-    hivemind.declare_experts(other_peer, [other_expert], f'that_host:{other_port}')
+    declare_experts(other_peer, [other_expert], f'that_host:{other_port}')
     first_notfound, first_found = get_experts(first_peer, ['foobar', other_expert])
     assert isinstance(first_found, hivemind.RemoteExpert)
     assert first_found.endpoint == f'that_host:{other_port}'
@@ -43,8 +44,8 @@ def test_store_get_experts():
     time.sleep(1.0)
     remaining_peer1 = random.choice([peer for peer in peers if peer.is_alive()])
     remaining_peer2 = random.choice([peer for peer in peers if peer.is_alive()])
-    assert all(hivemind.declare_experts(remaining_peer1, ['new_expert.1'], 'dummy'))
-    assert hivemind.get_experts(remaining_peer2, ['new_expert.1'])[0].endpoint == 'dummy'
+    assert all(declare_experts(remaining_peer1, ['new_expert.1'], 'dummy'))
+    assert get_experts(remaining_peer2, ['new_expert.1'])[0].endpoint == 'dummy'
 
 
 @pytest.mark.forked
@@ -149,7 +150,7 @@ async def test_negative_caching():
         peers.append(hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True))
 
     writer_peer = random.choice(peers)
-    assert all(hivemind.declare_experts(writer_peer, ['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
+    assert all(declare_experts(writer_peer, ['ffn.1.2.3', 'ffn.3.4.5'], 'myaddr:1234').values())
 
     neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
     neg_caching_peer = hivemind.DHT(initial_peers=neighbors_i, cache_locally=False, start=True)
@@ -157,7 +158,7 @@ async def test_negative_caching():
     # get prefixes by the peer with negative caching. Cache "no data" entries for ffn.0.*, ffn.2.*, ffn.4.*, ffn.5.*
     assert len(beam_search.get_initial_beam(scores=[.1, .2, .3, .4, .5, .6], beam_size=3)) == 2
 
-    node = await hivemind.DHTNode.create(initial_peers=neighbors_i)
+    node = await DHTNode.create(initial_peers=neighbors_i)
     fetched = await asyncio.gather(*(node.get(f'ffn.{i}.') for i in range(10)))
     for i in range(6):
         assert fetched[i] is not None, f"node should have cached ffn.{i}."

+ 10 - 10
tests/test_dht_node.py

@@ -317,8 +317,8 @@ async def test_dhtnode_replicas():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_caching(T=0.05):
-    node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
-    node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
+    node2 = await DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
+    node1 = await DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
                                           cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
     await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
     await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
@@ -366,7 +366,7 @@ async def test_dhtnode_reuse_get():
     peers = []
     for i in range(10):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+        peers.append(await DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
 
     await asyncio.gather(
         random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
@@ -396,10 +396,10 @@ async def test_dhtnode_reuse_get():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_blacklist():
-    node1 = await hivemind.DHTNode.create(blacklist_time=999)
-    node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node3 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node4 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node1 = await DHTNode.create(blacklist_time=999)
+    node2 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node3 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node4 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
 
     assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
     assert len(node2.blacklist.ban_counter) == 0
@@ -428,9 +428,9 @@ async def test_dhtnode_blacklist():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
-    node1 = await hivemind.DHTNode.create(blacklist_time=999)
+    node1 = await DHTNode.create(blacklist_time=999)
     with pytest.raises(ValidationError):
-        node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
+        node2 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
                                               endpoint=fake_endpoint)
 
 
@@ -440,7 +440,7 @@ async def test_dhtnode_edge_cases():
     peers = []
     for i in range(5):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))
+        peers.append(await DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))
 
     subkeys = [0, '', False, True, 'abyrvalg', 4555]
     keys = subkeys + [()]

+ 1 - 1
tests/test_dht_schema.py

@@ -3,7 +3,7 @@ from pydantic import BaseModel, StrictInt, conint
 from typing import Dict
 
 import hivemind
-from hivemind.dht import get_dht_time
+from hivemind.utils.timed_storage import get_dht_time
 from hivemind.dht.node import DHTNode, LOCALHOST
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.dht.validation import DHTRecord, RecordValidatorBase

+ 1 - 1
tests/test_dht_storage.py

@@ -1,6 +1,6 @@
 import time
 
-from hivemind.dht.routing import get_dht_time
+from hivemind.utils.timed_storage import get_dht_time
 from hivemind.dht.storage import DHTLocalStorage, DHTID, DictionaryDHTValue
 from hivemind.utils.serializer import MSGPackSerializer
 

+ 2 - 2
tests/test_expert_backend.py

@@ -6,8 +6,8 @@ import torch
 from torch.nn import Linear
 
 from hivemind import BatchTensorDescriptor, ExpertBackend
-from hivemind.server.checkpoints import store_experts, load_experts
-from hivemind.server.layers.lr_schedule import get_linear_schedule_with_warmup
+from hivemind.moe.server.checkpoints import store_experts, load_experts
+from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup
 
 EXPERT_WEIGHT_UPDATES = 3
 BACKWARD_PASSES_BEFORE_SAVE = 2

+ 7 - 7
tests/test_moe.py

@@ -4,9 +4,9 @@ import pytest
 import torch
 
 import hivemind
-from hivemind import background_server
-from hivemind.client.expert import DUMMY
-from hivemind.server import layers
+from hivemind.moe.server import background_server, declare_experts
+from hivemind.moe.client.expert import DUMMY
+from hivemind.moe.server import layers
 
 
 @pytest.mark.forked
@@ -60,7 +60,7 @@ def test_call_many(hidden_dim=16):
         e0, e1, e2, e3, e4 = [hivemind.RemoteExpert(f'expert.{i}', server_endpoint) for i in range(5)]
         e5 = hivemind.RemoteExpert(f'thisshouldnotexist', '127.0.0.1:80')
 
-        mask, expert_outputs = hivemind.client.moe._RemoteCallMany.apply(
+        mask, expert_outputs = hivemind.moe.client.moe._RemoteCallMany.apply(
             DUMMY, [[e0, e1, e2], [e2, e4], [e1, e5, e3], []], k_min, backward_k_min, timeout_after_k_min,
             forward_timeout, backward_timeout, detect_anomalies, allow_zero_outputs, e1.info, inputs
         )
@@ -120,7 +120,7 @@ def test_remote_module_call(hidden_dim=16):
 def test_beam_search_correctness():
     all_expert_uids = [f'ffn.{5 + i}.{10 + j}.{15 + k}' for i in range(10) for j in range(10) for k in range(10)]
     dht = hivemind.DHT(start=True)
-    assert all(hivemind.declare_experts(dht, all_expert_uids, endpoint='fake-endpoint'))
+    assert all(declare_experts(dht, all_expert_uids, endpoint='fake-endpoint'))
 
     dmoe = hivemind.RemoteMixtureOfExperts(
         in_features=32, grid_size=(32, 32, 32), dht=dht, k_best=4, uid_prefix='ffn.')
@@ -168,7 +168,7 @@ def test_determinism(hidden_dim=16):
 def test_compute_expert_scores():
     try:
         dht = hivemind.DHT(start=True)
-        moe = hivemind.client.moe.RemoteMixtureOfExperts(
+        moe = hivemind.moe.RemoteMixtureOfExperts(
             dht=dht, in_features=16, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
             uid_prefix='expert.')
         gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
@@ -209,7 +209,7 @@ def test_client_anomaly_detection():
     experts['expert.3'].expert.ffn.weight.data[0, 0] = float('nan')
 
     dht = hivemind.DHT(start=True)
-    server = hivemind.Server(dht, experts, num_connection_handlers=1)
+    server = hivemind.moe.Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:
         server.ready.wait()

+ 2 - 2
tests/test_p2p_daemon.py

@@ -10,11 +10,11 @@ import pytest
 import torch
 from multiaddr import Multiaddr
 
-from hivemind.p2p import P2P, P2PHandlerError, PeerID, PeerInfo
+from hivemind.p2p import P2P, P2PHandlerError
 from hivemind.proto import dht_pb2, runtime_pb2
 from hivemind.utils import MSGPackSerializer
-from hivemind.utils.networking import find_open_port
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.networking import find_open_port
 
 
 def is_process_running(pid: int) -> bool:

+ 4 - 2
tests/test_training.py

@@ -7,8 +7,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
-from hivemind import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts, background_server, DHT, \
-    DecentralizedSGD, DecentralizedAdam
+from hivemind import DHT
+from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
+from hivemind.moe.server import background_server
+from hivemind.optim import DecentralizedSGD, DecentralizedAdam
 
 
 @pytest.mark.forked

+ 3 - 3
tests/test_util_modules.py

@@ -4,17 +4,17 @@ import multiprocessing as mp
 import random
 import time
 
+import numpy as np
 import pytest
 import torch
-import numpy as np
 
 import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import MSGPackSerializer
-from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
+from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
 
@@ -439,7 +439,7 @@ def test_split_parts():
     for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
         with pytest.raises(RuntimeError):
             deserialize_torch_tensor(combined)
-            # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceRunner
+            # note: we rely on this being RuntimeError in hivemind.averaging.allreduce.AllreduceRunner
 
 
 def test_generic_data_classes():

+ 1 - 1
tests/test_utils/custom_networks.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from hivemind.server.layers.custom_experts import register_expert_class
+from hivemind.moe import register_expert_class
 
 sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim))