
massive refactors (#74)

* remove JoblibSerializer

* rename SharedFuture to MPFuture

* move DUMMY to hivemind.expert, remove hivemind.utils.data

* style: line length <130

* changed CheckpointSaver.dir => checkpoint_dir (rationale: same name everywhere, avoid shadowing builtin dir)

* rename SharedFuture to MPFuture

* rename conn_handler_processes => num_connection_handlers (rationale: to make it similar to num_experts, num_workers, num_replicas, num_threads)

* remove hivemind.utils.data

* typo sucessfully => successfully

* rename SharedFuture to MPFuture

* move hivemind.runtime.* into hivemind.server

* tuple endpoint => string endpoint (#76)

* wip: passed test_dht

* wip: passed test_training

* wip: all tests should work now

* wip: benchmark_dht works now

* wip: benchmark_throughput works now

* DHT now accepts initial_peers same way as DHTNode

* fix initial_peers in test_training

* review by @mryab

* review: move import multiprocessing.* to files that require them

* review: remove unused import ctypes

* review: add todo

* review: inline strip_endpoint
justheuristic, 5 years ago
commit f496f2c14a

+ 1 - 4
docs/modules/server.rst

@@ -1,4 +1,4 @@
-``hivemind.server & runtime``
+**Hivemind Server**
 ========================================
 
 .. automodule:: hivemind.server
@@ -9,13 +9,10 @@
    :members:
    :member-order: bysource
 
-.. currentmodule:: hivemind.runtime
-
 .. autoclass:: Runtime
     :members:
     :member-order: bysource
 
-
 .. autoclass:: ExpertBackend
     :members: forward, backward, apply_gradients, get_info, get_pools
     :member-order: bysource

+ 1 - 2
hivemind/__init__.py

@@ -1,7 +1,6 @@
 from hivemind.client import *
 from hivemind.dht import *
-from hivemind.server import Server
+from hivemind.server import *
 from hivemind.utils import *
-from hivemind.runtime import *
 
 __version__ = '0.7.1'

+ 13 - 14
hivemind/client/expert.py

@@ -7,9 +7,11 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.utils import nested_flatten, DUMMY, nested_pack, nested_compare
+from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
 from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, runtime_pb2, runtime_grpc
 
+DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
+
 
 class RemoteExpert(nn.Module):
     """
@@ -20,20 +22,18 @@ class RemoteExpert(nn.Module):
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
 
     :param uid: unique expert identifier
-    :param host: hostname where server operates
-    :param port: port to which server listens
+    :param endpoint: network endpoint of a server that services that expert, e.g. "201.123.321.99:1337" or "[::]:8080"
     """
 
-    def __init__(self, uid, host='127.0.0.1', port=8080):
+    def __init__(self, uid, endpoint: Endpoint):
         super().__init__()
-        self.uid, self.host, self.port = uid, host, port
-        self._channel, self._stub = None, None
-        self._info = None
+        self.uid, self.endpoint = uid, endpoint
+        self._channel, self._stub, self._info = None, None, None
 
     @property
     def stub(self):
         if self._channel is None:
-            self._channel = grpc.insecure_channel(f'{self.host}:{self.port}', options=[
+            self._channel = grpc.insecure_channel(self.endpoint, options=[
                 ('grpc.max_send_message_length', -1),
                 ('grpc.max_receive_message_length', -1)
             ])
@@ -57,8 +57,7 @@
         if not nested_compare(forward_inputs, self.info['forward_schema']):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
-        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, self.stub,
-                                               *nested_flatten(forward_inputs))
+        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, *nested_flatten(forward_inputs))
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
 
@@ -70,18 +69,18 @@
         return self._info
 
     def extra_repr(self):
-        return f"uid={self.uid}, host={self.host}, port={self.port}"
+        return f"uid={self.uid}, endpoint={self.endpoint}"
 
 
 class _RemoteModuleCall(torch.autograd.Function):
     """ Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead. """
 
     @staticmethod
-    def forward(ctx, dummy: torch.Tensor, uid: str, host: str, port: int, stub: runtime_grpc.ConnectionHandlerStub,
+    def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
                 *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
         inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
-        ctx.uid, ctx.host, ctx.port, ctx.stub = uid, host, port, stub
+        ctx.uid, ctx.stub = uid, stub
         ctx.save_for_backward(*inputs)
 
         outputs = stub.forward(
@@ -100,4 +99,4 @@
             runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
-        return (DUMMY, None, None, None, None, *deserialized_grad_inputs)
+        return (DUMMY, None, None, *deserialized_grad_inputs)
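
For context, a minimal usage sketch of the endpoint-based RemoteExpert above (the uid, endpoint and input shape are made-up values, and a hivemind server is assumed to already serve that expert):

    import torch
    from hivemind.client import RemoteExpert

    # a single endpoint string replaces the old host/port pair
    expert = RemoteExpert(uid="expert.0", endpoint="127.0.0.1:8080")
    outputs = expert(torch.randn(1, 512))  # inputs must match the expert's forward schema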

+ 6 - 6
hivemind/client/moe.py

@@ -6,8 +6,8 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.client.expert import RemoteExpert, _RemoteModuleCall
-from hivemind.utils import nested_map, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background, \
+from hivemind.client.expert import RemoteExpert, _RemoteModuleCall, DUMMY
+from hivemind.utils import nested_map, run_and_await_k, nested_pack, nested_flatten, run_in_background, \
     run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
 
 
@@ -43,7 +43,8 @@ class RemoteMixtureOfExperts(nn.Module):
         self.dht, self.grid_size = dht, grid_size
         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
-        self.forward_timeout, self.timeout_after_k_min, self.backward_timeout = forward_timeout, timeout_after_k_min, backward_timeout
+        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
+        self.timeout_after_k_min = timeout_after_k_min
         self.allow_broadcasting = allow_broadcasting
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
@@ -258,11 +259,10 @@ class _RemoteMoECall(torch.autograd.Function):
     @staticmethod
     def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
         """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
-        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, expert.stub,
-                                    *nested_flatten((args, kwargs)))
+        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.stub, *nested_flatten((args, kwargs)))
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
         backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
-        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, no_grad_stub, *grad_inputs = backward_result
+        grad_dummy, no_grad_uid, no_grad_stub, *grad_inputs = backward_result
         return grad_inputs

+ 20 - 21
hivemind/dht/__init__.py

@@ -16,21 +16,21 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 import warnings
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 import uvloop
 
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time
-from hivemind.utils import SharedFuture, Endpoint, run_in_background
+from hivemind.utils import MPFuture, Endpoint, run_in_background
 
 
 class DHT(mp.Process):
     """
     A high-level interface to hivemind DHT. Runs a dht node in a background process.
 
-    :param initial_peers: one or multiple pairs of (host, port) pointing to active DHT peers. Default: no peers
+    :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
     :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
@@ -42,12 +42,12 @@ class DHT(mp.Process):
     EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
     make_key = "{}::{}".format
 
-    def __init__(self, *initial_peers: Endpoint, listen_on: Endpoint = "0.0.0.0:*", start: bool, daemon: bool = True,
-                 max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
+    def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
+                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
         super().__init__()
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
-        self._port = mp.Value(ctypes.c_int32, 0)  # initialized after server starts
+        self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self.node: Optional[DHTNode] = None  # initialized inside self.run only
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
@@ -99,11 +99,11 @@ class DHT(mp.Process):
         :param expiration: returns experts that expire no sooner than this (based on get_dht_time), default = now
         :returns: a list of [RemoteExpert if found else None]
         """
-        future, _future = SharedFuture.make_pair()
+        future, _future = MPFuture.make_pair()
         self.pipe.send(('_get_experts', [], dict(uids=uids, expiration=expiration, future=_future)))
         return future.result()
 
-    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: SharedFuture):
+    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: MPFuture):
         loop = asyncio.get_event_loop()
         expiration = expiration or get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
@@ -114,41 +114,40 @@ class DHT(mp.Process):
 
         experts: List[Optional[RemoteExpert]] = [None] * len(uids)
         for i, (key, uid) in enumerate(zip(keys, uids)):
-            maybe_result, maybe_expiration = response[key]
+            maybe_endpoint, maybe_expiration = response[key]
             if maybe_expiration is not None:  # if we found a value
-                experts[i] = RemoteExpert(uid=uid, host=maybe_result[0], port=maybe_result[1])
+                experts[i] = RemoteExpert(uid=uid, endpoint=maybe_endpoint)
 
         future.set_result(experts)
 
-    def declare_experts(self, uids: List[str], addr, port, wait=True, timeout=None) -> Optional[List[bool]]:
+    def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
-        Make experts available to DHT; update timestamps if already available
+        Make experts visible to all DHT peers; update timestamps if declared previously.
 
         :param uids: a list of expert ids to update
-        :param addr: hostname that can be used to call this expert
-        :param port: port that can be used to call this expert
+        :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
         :param wait: if True, awaits for declaration to finish, otherwise runs in background
         :param timeout: waits for the procedure to finish, None means wait indeninitely
         :returns: if wait, returns a list of booleans, (True = store succeeded, False = store rejected)
         """
-        future, _future = SharedFuture.make_pair() if wait else (None, None)
-        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, future=_future)))
+        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+        future, _future = MPFuture.make_pair() if wait else (None, None)
+        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), endpoint=endpoint, future=_future)))
         if wait:
             return future.result(timeout)
 
-    def _declare_experts(self, uids: List[str], addr: str, port: int, future: Optional[SharedFuture]):
+    def _declare_experts(self, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
         assert self.node is not None, "This method should only be accessed from inside .run method"
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         loop = asyncio.get_event_loop()
         expiration_time = get_dht_time() + self.EXPIRATION
         unique_prefixes = set()
-        coroutines = []
 
         keys, values = [], []
         for uid in uids:
             uid_parts = uid.split(self.UID_DELIMETER)
             keys.append(self.make_key('expert', uid))
-            values.append((addr, port))
+            values.append(endpoint)
             unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])
 
         for prefix in unique_prefixes:
@@ -171,12 +170,12 @@ class DHT(mp.Process):
         :returns: a list of at most :k: prefixes that have at least one active expert each;
         """
         assert isinstance(prefixes, (list, tuple)), "please provide a list/tuple of prefixes as the first argument"
-        future, _future = SharedFuture.make_pair()
+        future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
                         dict(prefixes=prefixes, k=k, max_prefetch=max_prefetch or k, future=_future)))
         return future.result()
 
-    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: SharedFuture):
+    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: MPFuture):
         assert self.node is not None, "This method should only be accessed from inside .run method"
         max_prefetch = max_prefetch or len(prefixes)
         loop = asyncio.get_event_loop()
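
A rough sketch of the reworked DHT interface above: listen_on and initial_peers are plain string endpoints, and declare_experts takes one endpoint instead of separate addr and port (all addresses and uids below are hypothetical):

    from hivemind.dht import DHT

    # the first node starts without peers; the second bootstraps from it
    first = DHT(listen_on="0.0.0.0:1337", start=True)
    second = DHT(initial_peers=["127.0.0.1:1337"], start=True)

    # experts are declared under a single endpoint string, e.g. the serving host and port
    second.declare_experts(["expert.0", "expert.1"], endpoint="127.0.0.1:8080")
    found = second.get_experts(["expert.0"])  # -> [RemoteExpert or None]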

+ 6 - 6
hivemind/dht/node.py

@@ -314,7 +314,7 @@ class DHTNode:
 
         # search metadata
         unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
-        node_to_addr: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries
+        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries
 
         SearchResult = namedtuple("SearchResult", ["binary_value", "expiration", "source_node_id"])
         latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}
@@ -331,18 +331,18 @@
 
         # stage 2: traverse the DHT for any unfinished keys
         for key_id in unfinished_key_ids:
-            node_to_addr.update(self.protocol.routing_table.get_nearest_neighbors(
+            node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
             queries = list(queries)
-            response = await self.protocol.call_find(node_to_addr[peer], queries)
+            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
             output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
             for key_id, (maybe_value, maybe_expiration, peers) in response.items():
-                node_to_addr.update(peers)
+                node_to_endpoint.update(peers)
                 if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
                     latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, peer)
                 should_interrupt = (latest_results[key_id].expiration >= sufficient_expiration_time)
@@ -350,7 +350,7 @@
             return output
 
         nearest_nodes_per_query, visited_nodes = await traverse_dht(
-            queries=list(unfinished_key_ids), initial_nodes=list(node_to_addr),
+            queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
             beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
             get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})
 
@@ -367,7 +367,7 @@
                     if node_id == latest_node_id:
                         continue
                     asyncio.create_task(self.protocol.call_store(
-                        node_to_addr[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
+                        node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
                     num_cached_nodes += 1
                     if num_cached_nodes >= self.cache_nearest:
                         break

+ 5 - 7
hivemind/dht/protocol.py

@@ -4,7 +4,6 @@ from __future__ import annotations
 import asyncio
 import heapq
 import os
-import urllib.parse
 from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
 from warnings import warn
 
@@ -12,7 +11,7 @@ import grpc
 import grpc.experimental.aio
 
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
-from hivemind.utils import Endpoint, compile_grpc, get_logger
+from hivemind.utils import Endpoint, compile_grpc, get_logger, replace_port, get_port
 
 logger = get_logger(__name__)
 
@@ -40,7 +39,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         See DHTNode (node.py) for a more detailed description.
 
         :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
-         for instance, def rpc_ping can be called as protocol.call_ping(addr, dht_id) from a remote machine
+         for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
          Only the call_* methods are meant to be called publicly, e.g. from DHTNode
         Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
         """
@@ -109,9 +108,8 @@
         """ Some node wants us to add it to our routing table. """
         if peer_info.node_id and peer_info.rpc_port:
             sender_id = DHTID.from_bytes(peer_info.node_id)
-            peer_url = urllib.parse.urlparse(context.peer())
-            address = peer_url.path[:peer_url.path.rindex(':')]
-            asyncio.create_task(self.update_routing_table(sender_id, f"{address}:{peer_info.rpc_port}"))
+            rpc_endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
+            asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
         return self.node_info
 
     async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
@@ -193,7 +191,7 @@
                 output[key] = (value, expiration, nearest)
             return output
         except grpc.experimental.aio.AioRpcError as error:
-            logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 
     async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:

+ 10 - 10
hivemind/dht/routing.py

@@ -47,7 +47,7 @@ class RoutingTable:
 
     def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
         """
-        Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
+        Update routing table after an incoming request from :endpoint: or outgoing request to :endpoint:
 
         :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
         :note: DHTProtocol calls this method for every incoming and outgoing request if there was a response.
@@ -91,7 +91,7 @@
         """ Find endpoint for a given DHTID or vice versa """
         return self.uid_to_endpoint[item] if isinstance(item, DHTID) else self.endpoint_to_uid[item]
 
-    def __setitem__(self, node_id: DHTID, addr: Endpoint) -> NotImplementedError:
+    def __setitem__(self, node_id: DHTID, endpoint: Endpoint) -> NotImplementedError:
         raise NotImplementedError("RoutingTable doesn't support direct item assignment. Use table.try_add_node instead")
 
     def __contains__(self, item: Union[DHTID, Endpoint]) -> bool:
@@ -160,7 +160,7 @@
 class KBucket:
     """
     A bucket containing up to :size: of DHTIDs in [lower, upper) semi-interval.
-    Maps DHT node ids to their endpoints (hostname, addr)
+    Maps DHT node ids to their endpoints
     """
 
     def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
@@ -175,13 +175,13 @@
         """ Check if node_id is between this bucket's lower and upper bounds """
         return self.lower <= node_id < self.upper
 
-    def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> bool:
+    def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> bool:
         """
         Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
         If the bucket is full, keep track of node in a replacement list, per section 4.1 of the paper.
 
         :param node_id: dht node identifier that should be added or moved to the front of bucket
-        :param addr: a pair of (hostname, port) associated with that node id
+        :param endpoint: network address associated with that node id
         :note: this function has a side-effect of resetting KBucket.last_updated time
         """
         if node_id in self.nodes_requested_for_ping:
@@ -189,13 +189,13 @@
         self.last_updated = get_dht_time()
         if node_id in self.nodes_to_endpoint:
             del self.nodes_to_endpoint[node_id]
-            self.nodes_to_endpoint[node_id] = addr
+            self.nodes_to_endpoint[node_id] = endpoint
         elif len(self.nodes_to_endpoint) < self.size:
-            self.nodes_to_endpoint[node_id] = addr
+            self.nodes_to_endpoint[node_id] = endpoint
         else:
             if node_id in self.replacement_nodes:
                 del self.replacement_nodes[node_id]
-            self.replacement_nodes[node_id] = addr
+            self.replacement_nodes[node_id] = endpoint
             return False
         return True
 
@@ -229,9 +229,9 @@
         assert self.lower < midpoint < self.upper, f"Bucket to small to be split: [{self.lower}: {self.upper})"
         left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
         right = KBucket(midpoint, self.upper, self.size, depth=self.depth + 1)
-        for node_id, addr in chain(self.nodes_to_endpoint.items(), self.replacement_nodes.items()):
+        for node_id, endpoint in chain(self.nodes_to_endpoint.items(), self.replacement_nodes.items()):
             bucket = left if int(node_id) <= midpoint else right
-            bucket.add_or_update_node(node_id, addr)
+            bucket.add_or_update_node(node_id, endpoint)
         return left, right
 
     def __repr__(self):

+ 18 - 12
hivemind/server/__init__.py

@@ -1,12 +1,16 @@
 import multiprocessing as mp
+import multiprocessing.synchronize
 import threading
 from typing import Dict, Optional
 
 from hivemind.dht import DHT
-from hivemind.runtime import Runtime, ExpertBackend
+from hivemind.server.runtime import Runtime
+from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
+from hivemind.server.expert_backend import ExpertBackend
 from hivemind.server.checkpoint_saver import CheckpointSaver
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread
+from hivemind.utils import Endpoint, get_port, replace_port, find_open_port
 
 
 class Server(threading.Thread):
@@ -20,11 +24,10 @@ class Server(threading.Thread):
      - follows orders from HivemindController - if it exists
 
     :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, host=IPADDR, port=PORT).
+     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
-    :param addr: server's dht address that determines how it can be accessed. Default is local connections only.
-    :param port: port to which server listens for requests such as expert forward or backward pass.
-    :param conn_handler_processes: maximum number of simultaneous requests. Please note that the default value of 1
+    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
+    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
         if too small for normal functioning, we recommend 4 handlers per expert backend.
     :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
         if dht is None, this parameter is ignored.
@@ -32,13 +35,16 @@
         is ready (see .ready below)
     """
 
-    def __init__(self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
-                 port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False, checkpoint_dir=None, **kwargs):
+    def __init__(
+            self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], listen_on: Endpoint = "0.0.0.0:*",
+            num_connection_handlers: int = 1, update_period: int = 30, start=False, checkpoint_dir=None, **kwargs):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        self.addr, self.port = addr, port
-        self.conn_handlers = [ConnectionHandler(f"{self.addr}:{port}", self.experts)
-                              for _ in range(conn_handler_processes)]
+        if get_port(listen_on) is None:
+            self.listen_on = listen_on = replace_port(listen_on, new_port=find_open_port())
+        self.port = get_port(listen_on)
+
+        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:
             self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
         else:
@@ -57,8 +63,8 @@
             if not self.dht.is_alive():
                 self.dht.run_in_background(await_ready=True)
 
-            dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht,
-                                                  addr=self.addr, port=self.port, update_period=self.update_period)
+            dht_handler_thread = DHTHandlerThread(
+                experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
             dht_handler_thread.start()
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
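
A sketch of the updated Server constructor; the expert_backends dict is left empty here as a placeholder, since building real ExpertBackend objects is unchanged by this refactor and omitted:

    from hivemind.dht import DHT
    from hivemind.server import Server

    dht = DHT(listen_on="0.0.0.0:*", start=True)
    server = Server(
        dht,
        expert_backends={},         # placeholder; real usage passes {expert uid: ExpertBackend}
        listen_on="0.0.0.0:*",      # a "*" port is replaced with a free port via find_open_port()
        num_connection_handlers=4,  # formerly conn_handler_processes
        start=True,
    )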

+ 4 - 4
hivemind/server/checkpoint_saver.py

@@ -8,20 +8,20 @@ from typing import Dict
 
 import torch
 
-from hivemind.runtime import ExpertBackend
+from hivemind.server.expert_backend import ExpertBackend
 
 
 class CheckpointSaver(threading.Thread):
-    def __init__(self, expert_backends: Dict[str, ExpertBackend], dir: Path, update_period: int):
+    def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
         super().__init__()
         self.expert_backends = expert_backends
         self.update_period = update_period
-        self.dir = dir
+        self.checkpoint_dir = checkpoint_dir
        self.stop = False
 
     def run(self) -> None:
         while not self.stop:
-            store_experts(self.expert_backends, self.dir)
+            store_experts(self.expert_backends, self.checkpoint_dir)
             time.sleep(self.update_period)
 
 
 
 

+ 1 - 1
hivemind/server/connection_handler.py

@@ -8,7 +8,7 @@ import grpc.experimental.aio
 import torch
 import uvloop
 
-from hivemind.runtime.expert_backend import ExpertBackend
+from hivemind.server.expert_backend import ExpertBackend
 from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, runtime_pb2, runtime_grpc
 
 logger = get_logger(__name__)

+ 5 - 4
hivemind/server/dht_handler.py

@@ -2,13 +2,14 @@ import threading
 import time
 
 from hivemind.dht import DHT
+from hivemind.utils import Endpoint, get_port
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, update_period: int = 5, addr: str = '127.0.0.1', port: int = 8080):
+    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5):
         super(DHTHandlerThread, self).__init__()
-        self.port = port
-        self.addr = addr
+        assert get_port(endpoint) is not None
+        self.endpoint = endpoint
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
@@ -16,5 +17,5 @@ class DHTHandlerThread(threading.Thread):
 
     def run(self) -> None:
         while not self.stop:
-            self.dht.declare_experts(self.experts.keys(), self.addr, self.port)
+            self.dht.declare_experts(self.experts.keys(), self.endpoint)
             time.sleep(self.update_period)

+ 10 - 7
hivemind/runtime/expert_backend.py → hivemind/server/expert_backend.py

@@ -3,8 +3,9 @@ from typing import Dict, Sequence, Any, Tuple, Union
 import torch
 from torch import nn
 
-from hivemind.runtime.task_pool import TaskPool
-from hivemind.utils import nested_flatten, nested_pack, nested_compare, BatchTensorDescriptor, DUMMY_BATCH_SIZE, nested_map
+from hivemind.server.task_pool import TaskPool
+from hivemind.utils import nested_flatten, nested_pack, nested_compare, nested_map,\
+    BatchTensorDescriptor, DUMMY_BATCH_SIZE
 
 
 class ExpertBackend(nn.Module):
@@ -18,7 +19,7 @@ class ExpertBackend(nn.Module):
 
     :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
 
-     - Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of same type
+     - Experts must always receive the same set of args and kwargs and produce output tensors of same type
      - All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
      - We recommend using experts that are ~invariant to the order in which they process batches
      - Using randomness (e.g. Dropout) leads to different samples at forward and backward. If you want consistency,
@@ -95,15 +96,17 @@ class ExpertBackend(nn.Module):
 
            .. todo correct state handling (see forward)
 
-           Please make sure to call ``ExpertBackend.apply_gradients`` **within** this method, otherwise the expert will not train
+           Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
         with torch.enable_grad():
             args = [tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
                     else tensor.detach() for tensor in args]
-            kwargs = {input_key: (tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
-                                  else tensor.detach()) for input_key, tensor in kwargs.items()}
+            kwargs = {input_key: (tensor.detach().requires_grad_(True)
+                                  if tensor.dtype in (torch.half, torch.float, torch.double)
+                                  else tensor.detach())
+                      for input_key, tensor in kwargs.items()}
 
             outputs = self.expert(*args, **kwargs)
             assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
@@ -122,7 +125,7 @@
 
     def apply_gradients(self) -> None:
         """
-        Train the expert for a single step. This method is called by ``ExpertBackend.backward`` after computing gradients.
+        Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
         """
         self.opt.step()
         self.opt.zero_grad()

+ 2 - 2
hivemind/runtime/__init__.py → hivemind/server/runtime.py

@@ -1,4 +1,5 @@
 import multiprocessing as mp
+import multiprocessing.pool
 import threading
 from itertools import chain
 from selectors import DefaultSelector, EVENT_READ
@@ -7,8 +8,7 @@ from typing import Dict
 import torch
 from prefetch_generator import BackgroundGenerator
 
-from hivemind.runtime.expert_backend import ExpertBackend
-from hivemind.runtime.task_pool import TaskPool, TaskPoolBase
+from hivemind.server.expert_backend import ExpertBackend
 from hivemind.utils import get_logger
 
 logger = get_logger(__name__)

+ 5 - 4
hivemind/runtime/task_pool.py → hivemind/server/task_pool.py

@@ -3,6 +3,7 @@ Task pool is responsible for receiving tasks and grouping them together for proc
 """
 """
 import ctypes
 import ctypes
 import multiprocessing as mp
 import multiprocessing as mp
+import multiprocessing.context
 import os
 import os
 import threading
 import threading
 import time
 import time
@@ -14,7 +15,7 @@ from typing import List, Tuple, Dict, Any, Generator
 
 
 import torch
 import torch
 
 
-from hivemind.utils import SharedFuture, get_logger
+from hivemind.utils import MPFuture, get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
 Task = namedtuple("Task", ("future", "args"))
@@ -86,10 +87,10 @@ class TaskPool(TaskPoolBase):
 
 
     def submit_task(self, *args: torch.Tensor) -> Future:
     def submit_task(self, *args: torch.Tensor) -> Future:
         """ Add task to this pool's queue, return Future for its output """
         """ Add task to this pool's queue, return Future for its output """
-        future1, future2 = SharedFuture.make_pair()
+        future1, future2 = MPFuture.make_pair()
         task = Task(future1, args)
         task = Task(future1, args)
         if self.get_task_size(task) > self.max_batch_size:
         if self.get_task_size(task) > self.max_batch_size:
-            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it will never be finished")
+            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             future2.set_exception(exc)
             future2.set_exception(exc)
         else:
         else:
             self.tasks.put(task)
             self.tasks.put(task)
@@ -127,7 +128,7 @@ class TaskPool(TaskPoolBase):
     def run(self, *args, **kwargs):
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
         torch.set_num_threads(1)
         logger.info(f'{self.uid} starting, pid={os.getpid()}')
         logger.info(f'{self.uid} starting, pid={os.getpid()}')
-        pending_batches = {}  # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
+        pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
                                          name=f'{self.uid}_output')
                                          name=f'{self.uid}_output')
         try:
         try:

+ 2 - 3
hivemind/utils/__init__.py

@@ -1,9 +1,8 @@
-from hivemind.utils.connection import *
-from hivemind.utils.data import *
+from hivemind.utils.networking import *
 from hivemind.utils.nested import *
 from hivemind.utils.tensor_descr import *
 from hivemind.utils.serializer import *
-from hivemind.utils.shared_future import *
+from hivemind.utils.mpfuture import *
 from hivemind.utils.threading import *
 from hivemind.utils.autograd import *
 from hivemind.utils.grpc import *

+ 2 - 1
hivemind/utils/autograd.py

@@ -92,7 +92,8 @@ class _ParallelApplyFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grad_outputs_flat: torch.Tensor):
         func, contexts, output_strides = ctx._inner_func, ctx._call_contexts, ctx._output_strides
-        grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]] for i in range(len(contexts))]
+        grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]]
+                                 for i in range(len(contexts))]
         futures = [run_in_background(run_isolated_backward, func, context, *grads)
                    for context, grads in zip(contexts, grad_outputs_per_call)]
         flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())

+ 0 - 3
hivemind/utils/data.py

@@ -1,3 +0,0 @@
-import torch
-
-DUMMY = torch.empty(0, requires_grad=True)

+ 2 - 1
hivemind/utils/grpc.py

@@ -47,7 +47,8 @@ def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]:
                 raise ImportError("Something changed sys.path while compile_grpc was in progress.")
 
 
-with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'server', 'connection_handler.proto')) as f_proto:
+with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+                       'server', 'connection_handler.proto')) as f_proto:
     runtime_pb2, runtime_grpc = compile_grpc(f_proto.read())
 
 
 
 

+ 2 - 2
hivemind/utils/logging.py

@@ -8,8 +8,8 @@ def get_logger(module_name: str) -> logging.Logger:
     loglevel = os.getenv('LOGLEVEL', 'INFO')
 
     logging.addLevelName(logging.WARNING, 'WARN')
-    formatter = logging.Formatter(fmt='[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}', style='{',
-                                  datefmt='%Y/%m/%d %H:%M:%S')
+    formatter = logging.Formatter(fmt='[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}',
+                                  style='{', datefmt='%Y/%m/%d %H:%M:%S')
     handler = logging.StreamHandler()
     handler.setFormatter(formatter)
     logger = logging.getLogger(name_without_prefix)

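For reference, a minimal usage sketch of the reformatted logger (the module name and log message below are illustrative, not part of the diff):

    from hivemind.utils.logging import get_logger

    logger = get_logger(__name__)           # log level comes from the LOGLEVEL env var, default INFO
    logger.info("expert backend is ready")  # printed as [date time.msec][INFO][<name>.<func>:<line>] message
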
+ 1 - 1
hivemind/utils/shared_future.py → hivemind/utils/mpfuture.py

@@ -5,7 +5,7 @@ from warnings import warn
 import asyncio
 
 
-class SharedFuture(Future):
+class MPFuture(Future):
     """ Multiprocessing version of concurrent.futures.Future, interacts between two processes via Pipe """
     STATES = 'pending', 'running', 'cancelled', 'finished', 'exception'
     STATE_PENDING, STATE_RUNNING, STATE_CANCELLED, STATE_FINISHED, STATE_EXCEPTION = STATES

+ 16 - 0
hivemind/utils/connection.py → hivemind/utils/networking.py

@@ -1,11 +1,27 @@
 import socket
+import urllib.parse
 from contextlib import closing
+from typing import Optional
 
 Hostname, Port = str, int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = '127.0.0.1'
 
 
+def get_port(endpoint: Endpoint) -> Optional[Port]:
+    """ get port or None if port is undefined """
+    # TODO: find a standard way to get port, make sure it works in malformed ports
+    try:
+        return int(endpoint[endpoint.rindex(':') + 1:], base=10)
+    except ValueError:  # :* or not specified
+        return None
+
+
+def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
+    assert endpoint.endswith(':*') or get_port(endpoint) is not None, endpoint
+    return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
+
+
 def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
     try:

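A quick sketch of how the two new endpoint helpers behave; the example endpoints are illustrative, and the behavior follows directly from the implementation above:

    from hivemind.utils.networking import get_port, replace_port

    assert get_port("127.0.0.1:1337") == 1337
    assert get_port("0.0.0.0:*") is None                      # wildcard port is not known yet
    assert replace_port("0.0.0.0:*", 8080) == "0.0.0.0:8080"  # e.g. after the OS assigns a port
    assert replace_port("[::]:1337", 1338) == "[::]:1338"
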
+ 0 - 17
hivemind/utils/serializer.py

@@ -2,11 +2,9 @@
 import pickle
 from io import BytesIO
 
-import joblib
 import torch
 import umsgpack
 
-
 class SerializerBase:
     @staticmethod
     def dumps(obj: object) -> bytes:
@@ -17,19 +15,6 @@ class SerializerBase:
         raise NotImplementedError()
 
 
-class JoblibSerializer(SerializerBase):
-
-    @staticmethod
-    def dumps(obj: object) -> bytes:
-        s = BytesIO()
-        joblib.dump(obj, s)
-        return s.getvalue()
-
-    @staticmethod
-    def loads(buf: bytes) -> object:
-        return joblib.load(BytesIO(buf))
-
-
 class PickleSerializer(SerializerBase):
     @staticmethod
     def dumps(obj: object) -> bytes:
@@ -41,7 +26,6 @@ class PickleSerializer(SerializerBase):
 
 
 class PytorchSerializer(SerializerBase):
-
     @staticmethod
     def dumps(obj: object) -> bytes:
         s = BytesIO()
@@ -54,7 +38,6 @@ class PytorchSerializer(SerializerBase):
 
 
 class MSGPackSerializer(SerializerBase):
-
     @staticmethod
     def dumps(obj: object) -> bytes:
         return umsgpack.dumps(obj, use_bin_type=False)  # TODO strict https://github.com/msgpack/msgpack-python/pull/158

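With JoblibSerializer removed, MSGPackSerializer remains available and is what test_dht.py below uses to decode DHT values; a minimal round-trip sketch with illustrative values (exact decoding options live in loads, which is outside this hunk):

    from hivemind.utils.serializer import MSGPackSerializer

    blob = MSGPackSerializer.dumps([1337, 42])  # bytes, msgpack-encoded
    restored = MSGPackSerializer.loads(blob)    # back to python objects, e.g. [1337, 42]
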
+ 2 - 1
hivemind/utils/tensor_descr.py

@@ -45,7 +45,8 @@ class BatchTensorDescriptor(TensorDescriptor):
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor):
         return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
-                   device=tensor.device, requires_grad=tensor.requires_grad, pin_memory=torch.cuda.is_available() and tensor.is_pinned())
+                   device=tensor.device, requires_grad=tensor.requires_grad,
+                   pin_memory=torch.cuda.is_available() and tensor.is_pinned())
 
     def make_empty(self, batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"

+ 0 - 1
requirements.txt

@@ -1,5 +1,4 @@
 torch>=1.3.0
-joblib>=0.13
 numpy>=1.17
 prefetch_generator>=1.0.1
 umsgpack

+ 6 - 6
tests/benchmark_dht.py

@@ -9,9 +9,9 @@ from tqdm import trange
 from test_utils import increase_file_limit
 
 
-def random_endpoint() -> Tuple[str, int]:
-    return (f"{random.randint(0, 256)}.{random.randint(0, 256)}."
-            f"{random.randint(0, 256)}.{random.randint(0, 256)}", random.randint(0, 65535))
+def random_endpoint() -> hivemind.Endpoint:
+    return f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}." \
+           f"{random.randint(0, 256)}:{random.randint(0, 65535)}"
 
 
 def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_batch_size: int, random_seed: int,
@@ -23,7 +23,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     peers = []
     for _ in trange(num_peers):
         neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
-        peer = hivemind.DHT(*neighbors, start=True, wait_timeout=wait_timeout, listen_on=f'0.0.0.0:*')
+        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout, listen_on=f'0.0.0.0:*')
         peers.append(peer)
 
     store_peer, get_peer = peers[-2:]
@@ -41,7 +41,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
-        success_list = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], *endpoints[-1])
+        success_list = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], endpoints[-1])
         total_store_time += time.perf_counter() - store_start
 
         total_stores += len(success_list)
@@ -64,7 +64,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
         for i, expert in enumerate(get_result):
             if expert is not None and expert.uid == expert_uids[start + i] \
-                    and (expert.host, expert.port) == endpoints[start // expert_batch_size]:
+                    and expert.endpoint == endpoints[start // expert_batch_size]:
                 successful_gets += 1
 
     if time.perf_counter() - benchmark_started > expiration_time:

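The hunks above show the core convention change of this PR: an endpoint is now a single "host:port" string, and DHT neighbors are passed through the initial_peers keyword rather than positional arguments. A small sketch of the resulting API, mirroring the tests below (the port and expert uid are illustrative):

    import hivemind

    first = hivemind.DHT(start=True)
    second = hivemind.DHT(initial_peers=[f'{hivemind.LOCALHOST}:{first.port}'], start=True)

    second.declare_experts(['expert.0'], f'{hivemind.LOCALHOST}:1337')  # one endpoint string per declaration
    expert, = first.get_experts(['expert.0'])
    assert expert.endpoint == f'{hivemind.LOCALHOST}:1337'
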
+ 3 - 2
tests/benchmark_throughput.py

@@ -14,7 +14,7 @@ from hivemind import find_open_port
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [hivemind.RemoteExpert(f"expert{i}", port=port) for i in range(num_experts)]
+    experts = [hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -69,7 +69,8 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
                                                            max_batch_size=max_batch_size,
                                                            )
         timestamps['created_experts'] = time.perf_counter()
-        server = hivemind.Server(None, experts, port=port, conn_handler_processes=num_handlers, device=device)
+        server = hivemind.Server(None, experts, listen_on=f"{hivemind.LOCALHOST}:{port}",
+                                 num_connection_handlers=num_handlers, device=device)
         server.start()
         server.ready.wait()
         timestamps['server_ready'] = time.perf_counter()

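The same convention applies on both sides of the RPC boundary: hivemind.Server now takes listen_on and num_connection_handlers (instead of addr/port and conn_handler_processes), and a RemoteExpert is addressed by one endpoint string. A construction-only sketch (uid and port are illustrative; as in test_compute_expert_scores below, merely constructing a RemoteExpert does not contact a server):

    import hivemind

    expert = hivemind.RemoteExpert('expert.0', endpoint=f'{hivemind.LOCALHOST}:1337')
    assert expert.endpoint == f'{hivemind.LOCALHOST}:1337'  # calling expert(...) would require a live server
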
+ 15 - 12
tests/test_dht.py

@@ -5,7 +5,7 @@ import random
 import heapq
 import uuid
 from itertools import chain
-from typing import Optional
+from typing import Optional, Tuple
 import numpy as np
 
 import hivemind
@@ -68,11 +68,12 @@ def test_dht_protocol():
                 protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
             recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
             (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-            assert recv_id == peer2_id and recv_endpoint == f"{LOCALHOST}:{peer2_port}", \
+            assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
                 f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
 
-            assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-                                                                          f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+            assert recv_value == value and recv_expiration == expiration, \
+                f"call_find_value expected {value} (expires by {expiration}) " \
+                f"but got {recv_value} (expires by {recv_expiration})"
 
             # peer 2 must know about peer 1, but not have a *random* nonexistent value
             dummy_key = DHTID.generate()
@@ -89,7 +90,7 @@ def test_dht_protocol():
 
             if listen:
                 loop.run_until_complete(protocol.shutdown())
-            print("DHTProtocol test finished sucessfully!")
+            print("DHTProtocol test finished successfully!")
             test_success.set()
 
     tester = mp.Process(target=_tester, daemon=True)
@@ -178,13 +179,15 @@ def test_dht_node():
 
         # test 1: find self
         nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-        assert len(nearest) == 1 and nearest[me.node_id] == f"{LOCALHOST}:{me.port}"
+        assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
 
         # test 2: find others
         for i in range(10):
             ref_endpoint, query_id = random.choice(list(dht.items()))
             nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
-            assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)
+            assert len(nearest) == 1
+            found_node_id, found_endpoint = next(iter(nearest.items()))
+            assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
 
         # test 3: find neighbors to random nodes
         accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
@@ -266,7 +269,7 @@ def test_hivemind_dht():
     peers = [hivemind.DHT(start=True)]
     for i in range(10):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(hivemind.DHT(*neighbors_i, start=True))
+        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
 
     you: hivemind.dht.DHT = random.choice(peers)
     theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
@@ -281,10 +284,10 @@ def test_hivemind_dht():
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
     that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
-    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], 'that_host', that_guys_port)
+    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
     you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
     assert isinstance(you_found, hivemind.RemoteExpert)
-    assert you_found.host == 'that_host', you_found.port == that_guys_port
+    assert you_found.endpoint == f'that_host:{that_guys_port}'
 
     # test first_k_active
     assert theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10) == expert_uids[:10]
@@ -302,9 +305,9 @@ def test_hivemind_dht():
 
 def test_dht_single_node():
     node = hivemind.DHT(start=True)
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], hivemind.LOCALHOST, 1337))
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:{1337}"))
     for expert in node.get_experts(['e3', 'e2']):
-        assert expert.host == hivemind.LOCALHOST and expert.port == 1337
+        assert expert.endpoint == f"{hivemind.LOCALHOST}:{1337}"
     assert node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2) == ['e1', 'e3']
 
 
 

+ 6 - 6
tests/test_moe.py

@@ -19,8 +19,8 @@ def test_remote_module_call():
     random_proj = torch.randn_like(xx)
 
     with background_server(num_experts=num_experts, device='cpu', num_handlers=1,
-                           no_optimizer=True, no_dht=True) as (localhost, server_port, dht_port):
-        experts = [hivemind.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
+                           no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
+        experts = [hivemind.RemoteExpert(uid=f'expert.{i}', endpoint=server_endpoint) for i in range(num_experts)]
         moe_output, = hivemind.client.moe._RemoteMoECall.apply(
             logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
             [(None,), {}], xx)
@@ -51,8 +51,8 @@ def test_determinism():
     mask = torch.randint(0, 1, (32, 1024))
 
     with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1,
-                           no_optimizer=True, no_dht=True) as (interface, server_port, dht_port):
-        expert = hivemind.RemoteExpert(uid=f'expert.0', port=server_port)
+                           no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
+        expert = hivemind.RemoteExpert(uid=f'expert.0', endpoint=server_endpoint)
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -70,11 +70,11 @@ def test_compute_expert_scores():
         moe = hivemind.client.moe.RemoteMixtureOfExperts(
             dht=dht, in_features=1024, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
             uid_prefix='expert')
-        gx, gy = torch.randn(4, 5, requires_grad=True), torch.torch.randn(4, 3, requires_grad=True)
+        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
         ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
-            [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}')
+            [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}', endpoint="[::]:1337")
              for expert_i in range(len(ii[batch_i]))]
             for batch_i in range(len(ii))
         ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores

+ 4 - 3
tests/test_routing.py

@@ -115,16 +115,17 @@ def test_routing_table_search():
             k = random.randint(1, 100)
             query_id = DHTID.generate()
             exclude = query_id if random.random() < 0.5 else None
-            our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
+            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
             reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
             assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
-            assert all(our_addr == routing_table[our_node] for our_node, our_addr in zip(our_knn, our_addrs))
+            assert all(our_endpoint == routing_table[our_node]
+                       for our_node, our_endpoint in zip(our_knn, our_endpoints))
 
         # queries from table
         for i in range(1000):
             k = random.randint(1, 100)
             query_id = random.choice(all_active_neighbors)
-            our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
+            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
 
             reference_knn = heapq.nsmallest(k + 1, all_active_neighbors, key=query_id.xor_distance)
             if query_id in reference_knn:

+ 4 - 7
tests/test_training.py

@@ -1,4 +1,3 @@
-#%env CUDA_VISIBLE_DEVICES=
 import argparse
 from typing import Optional
 
@@ -6,21 +5,19 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from hivemind import RemoteExpert, find_open_port
+from hivemind import RemoteExpert, find_open_port, LOCALHOST
 from test_utils.run_server import background_server
 
 from sklearn.datasets import load_digits
 
 
 def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: float = 0.9):
-    if port is None:
-        port = find_open_port()
     dataset = load_digits()
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
 
-    with background_server(num_experts=2, device='cpu', port=port, hidden_dim=64):
-        expert1 = RemoteExpert('expert.0', host='127.0.0.1', port=port)
-        expert2 = RemoteExpert('expert.1', host='127.0.0.1', port=port)
+    with background_server(num_experts=2, device='cpu', hidden_dim=64) as (server_endpoint, _):
+        expert1 = RemoteExpert('expert.0', server_endpoint)
+        expert2 = RemoteExpert('expert.1', server_endpoint)
         model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))
 
         opt = torch.optim.SGD(model.parameters(), lr=0.05)

+ 19 - 13
tests/test_utils/run_server.py

@@ -3,20 +3,21 @@ import multiprocessing as mp
 from contextlib import contextmanager
 
 import resource
+from typing import Tuple
+
 import torch
 
 import hivemind
 from test_utils.layers import name_to_block, name_to_input
 
 
-def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024,
+def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hidden_dim=1024,
                       num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                       no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
                       UID_DELIMETER=hivemind.DHT.UID_DELIMETER, start=False, **kwargs) -> hivemind.Server:
     """
     Instantiate a server with several identical experts. See argparse comments below for details
-    :param interface: 'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6
-    :param port: main server will listen to this port, default = find open port
+    :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
     :param num_experts: run this many identical experts
     :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
     :param hidden_dim: main dimension for expert_cls
@@ -46,8 +47,9 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
     if not no_dht:
         if not len(initial_peers):
             print("No initial peers provided. Starting additional dht as an initial peer.")
-            dht_root = hivemind.DHT(
-                *initial_peers, listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}", start=True)
+            dht_root = hivemind.DHT(initial_peers=initial_peers,
+                                    listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}",
+                                    start=True)
             print(f"Initializing DHT with port {dht_root.port}")
             initial_peers = [f"{hivemind.LOCALHOST}:{dht_root.port}"]
         else:
@@ -55,8 +57,9 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
             if root_port is not None:
                 print(f"Warning: root_port={root_port} will not be used since we already have peers.")
 
-        dht = hivemind.DHT(
-            *initial_peers, listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}", start=True)
+        dht = hivemind.DHT(initial_peers=initial_peers,
+                           listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}",
+                           start=True)
         if verbose:
             print(f"Running dht node on port {dht.port}")
 
@@ -79,19 +82,19 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
                                                      )
     # actually start server
     server = hivemind.Server(
-        dht, experts, addr=interface, port=port or hivemind.find_open_port(),
-        conn_handler_processes=num_handlers, device=device)
+        dht, experts, listen_on=listen_on,
+        num_connection_handlers=num_handlers, device=device)
 
     if start:
         server.run_in_background(await_ready=True)
         if verbose:
-            print(f"Server started at {server.addr}:{server.port}")
+            print(f"Server started at {server.listen_on}")
             print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")
     return server
 
 
 @contextmanager
-def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs):
+def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.get_context("spawn").Process(
@@ -115,8 +118,11 @@ def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs):
 def _server_runner(pipe, *args, verbose, **kwargs):
     server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
     try:
-        dht_port = server.dht.port if server.dht is not None else None
-        pipe.send((server.addr, server.port, dht_port))
+        if server.dht is not None:
+            dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
+        else:
+            dht_listen_on = None
+        pipe.send((server.listen_on, dht_listen_on))
        pipe.recv()  # wait for shutdown signal
     finally:
         if verbose: