
massive refactors (#74)

* remove JoblibSerializer

* rename SharedFuture to MPFuture

* move DUMMY to hivemind.expert, remove hivemind.utils.data

* style: line length <130

* changed CheckpointSaver.dir => checkpoint_dir (rationale: same name everywhere, avoid shadowing builtin dir)

* rename SharedFuture to MPFuture

* rename conn_handler_processes => num_connection_handlers (rationale: to make it similar to num_experts, num_workers, num_replicas, num_threads)

* remove hivemind.utils.data

* typo sucessfully => successfully

* rename SharedFuture to MPFuture

* move hivemind.runtime.* into hivemind.server

* tuple endpoint => string endpoint (#76)

* wip: passed test_dht

* wip: passed test_training

* wip: all tests should work now

* wip: benchmark_dht works now

* wip: benchmark_throughput works now

* DHT now accepts initial_peers same way as DHTNode

* fix initial_peers in test_training

* review by @mryab

* review: move import multiprocessing.* to files that require them

* review: remove unused import ctypes

* review: add todo

* review: inline strip_endpoint
justheuristic, 5 years ago
parent
commit f496f2c14a

+ 1 - 4
docs/modules/server.rst

@@ -1,4 +1,4 @@
-``hivemind.server & runtime``
+**Hivemind Server**
 ========================================
 
 .. automodule:: hivemind.server
@@ -9,13 +9,10 @@
    :members:
    :member-order: bysource
 
-.. currentmodule:: hivemind.runtime
-
 .. autoclass:: Runtime
     :members:
     :member-order: bysource
 
-
 .. autoclass:: ExpertBackend
     :members: forward, backward, apply_gradients, get_info, get_pools
     :member-order: bysource

+ 1 - 2
hivemind/__init__.py

@@ -1,7 +1,6 @@
 from hivemind.client import *
 from hivemind.dht import *
-from hivemind.server import Server
+from hivemind.server import *
 from hivemind.utils import *
-from hivemind.runtime import *
 
 __version__ = '0.7.1'

+ 13 - 14
hivemind/client/expert.py

@@ -7,9 +7,11 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.utils import nested_flatten, DUMMY, nested_pack, nested_compare
+from hivemind.utils import nested_flatten, nested_pack, nested_compare, Endpoint
 from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, runtime_pb2, runtime_grpc
 
+DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
+
 
 class RemoteExpert(nn.Module):
     """
@@ -20,20 +22,18 @@ class RemoteExpert(nn.Module):
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
 
     :param uid: unique expert identifier
-    :param host: hostname where server operates
-    :param port: port to which server listens
+    :param endpoint: network endpoint of a server that services that expert, e.g. "201.123.321.99:1337" or "[::]:8080"
     """
 
-    def __init__(self, uid, host='127.0.0.1', port=8080):
+    def __init__(self, uid, endpoint: Endpoint):
         super().__init__()
-        self.uid, self.host, self.port = uid, host, port
-        self._channel, self._stub = None, None
-        self._info = None
+        self.uid, self.endpoint = uid, endpoint
+        self._channel, self._stub, self._info = None, None, None
 
     @property
     def stub(self):
         if self._channel is None:
-            self._channel = grpc.insecure_channel(f'{self.host}:{self.port}', options=[
+            self._channel = grpc.insecure_channel(self.endpoint, options=[
                 ('grpc.max_send_message_length', -1),
                 ('grpc.max_receive_message_length', -1)
             ])
@@ -57,8 +57,7 @@ class RemoteExpert(nn.Module):
         if not nested_compare(forward_inputs, self.info['forward_schema']):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
-        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, self.stub,
-                                               *nested_flatten(forward_inputs))
+        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, *nested_flatten(forward_inputs))
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
         return nested_pack(flat_outputs, structure=self.info['outputs_schema'])
 
@@ -70,18 +69,18 @@ class RemoteExpert(nn.Module):
         return self._info
 
     def extra_repr(self):
-        return f"uid={self.uid}, host={self.host}, port={self.port}"
+        return f"uid={self.uid}, endpoint={self.endpoint}"
 
 
 class _RemoteModuleCall(torch.autograd.Function):
     """ Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead. """
 
     @staticmethod
-    def forward(ctx, dummy: torch.Tensor, uid: str, host: str, port: int, stub: runtime_grpc.ConnectionHandlerStub,
+    def forward(ctx, dummy: torch.Tensor, uid: str, stub: runtime_grpc.ConnectionHandlerStub,
                 *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
         inputs = tuple(map(torch.Tensor.detach, inputs))  # detach to avoid pickling the computation graph
-        ctx.uid, ctx.host, ctx.port, ctx.stub = uid, host, port, stub
+        ctx.uid, ctx.stub = uid, stub
         ctx.save_for_backward(*inputs)
 
         outputs = stub.forward(
@@ -100,4 +99,4 @@ class _RemoteModuleCall(torch.autograd.Function):
             runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[serialize_torch_tensor(tensor) for tensor in payload]))
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
-        return (DUMMY, None, None, None, None, *deserialized_grad_inputs)
+        return (DUMMY, None, None, *deserialized_grad_inputs)
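
For reference, a minimal usage sketch of the new interface (not part of the diff; the endpoint and input size are illustrative and assume a server hosting 'expert.0' is already listening there). RemoteExpert is now constructed from a single "host:port" endpoint string instead of separate host and port arguments:

    import torch
    import hivemind

    # assumes a running hivemind server that hosts an expert named 'expert.0' with 1024-dim inputs
    expert = hivemind.RemoteExpert(uid='expert.0', endpoint='127.0.0.1:8080')
    output = expert(torch.randn(2, 1024))  # forwarded over gRPC; DUMMY keeps the call attached to autograd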

+ 6 - 6
hivemind/client/moe.py

@@ -6,8 +6,8 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-from hivemind.client.expert import RemoteExpert, _RemoteModuleCall
-from hivemind.utils import nested_map, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background, \
+from hivemind.client.expert import RemoteExpert, _RemoteModuleCall, DUMMY
+from hivemind.utils import nested_map, run_and_await_k, nested_pack, nested_flatten, run_in_background, \
     run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
 
 
@@ -43,7 +43,8 @@ class RemoteMixtureOfExperts(nn.Module):
         self.dht, self.grid_size = dht, grid_size
         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
-        self.forward_timeout, self.timeout_after_k_min, self.backward_timeout = forward_timeout, timeout_after_k_min, backward_timeout
+        self.forward_timeout, self.backward_timeout = forward_timeout, backward_timeout
+        self.timeout_after_k_min = timeout_after_k_min
         self.allow_broadcasting = allow_broadcasting
 
         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
@@ -258,11 +259,10 @@ class _RemoteMoECall(torch.autograd.Function):
     @staticmethod
     def _run_expert_forward(expert: RemoteExpert, *args: torch.Tensor, **kwargs: torch.Tensor):
         """ Call remote expert and return flattened outputs. Compatible with concurrent autograd. """
-        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.host, expert.port, expert.stub,
-                                    *nested_flatten((args, kwargs)))
+        return run_isolated_forward(_RemoteModuleCall, DUMMY, expert.uid, expert.stub, *nested_flatten((args, kwargs)))
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
         backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
-        grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, no_grad_stub, *grad_inputs = backward_result
+        grad_dummy, no_grad_uid, no_grad_stub, *grad_inputs = backward_result
         return grad_inputs

+ 20 - 21
hivemind/dht/__init__.py

@@ -16,21 +16,21 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 import warnings
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 import uvloop
 
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time
-from hivemind.utils import SharedFuture, Endpoint, run_in_background
+from hivemind.utils import MPFuture, Endpoint, run_in_background
 
 
 class DHT(mp.Process):
     """
     A high-level interface to hivemind DHT. Runs a dht node in a background process.
 
-    :param initial_peers: one or multiple pairs of (host, port) pointing to active DHT peers. Default: no peers
+    :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
     :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
@@ -42,12 +42,12 @@ class DHT(mp.Process):
     EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
     make_key = "{}::{}".format
 
-    def __init__(self, *initial_peers: Endpoint, listen_on: Endpoint = "0.0.0.0:*", start: bool, daemon: bool = True,
-                 max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
+    def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
+                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
         super().__init__()
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
         self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
-        self._port = mp.Value(ctypes.c_int32, 0)  # initialized after server starts
+        self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
         self.node: Optional[DHTNode] = None  # initialized inside self.run only
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
@@ -99,11 +99,11 @@ class DHT(mp.Process):
         :param expiration: returns experts that expire no sooner than this (based on get_dht_time), default = now
         :returns: a list of [RemoteExpert if found else None]
         """
-        future, _future = SharedFuture.make_pair()
+        future, _future = MPFuture.make_pair()
         self.pipe.send(('_get_experts', [], dict(uids=uids, expiration=expiration, future=_future)))
         return future.result()
 
-    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: SharedFuture):
+    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: MPFuture):
         loop = asyncio.get_event_loop()
         expiration = expiration or get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
@@ -114,41 +114,40 @@ class DHT(mp.Process):
 
         experts: List[Optional[RemoteExpert]] = [None] * len(uids)
         for i, (key, uid) in enumerate(zip(keys, uids)):
-            maybe_result, maybe_expiration = response[key]
+            maybe_endpoint, maybe_expiration = response[key]
             if maybe_expiration is not None:  # if we found a value
-                experts[i] = RemoteExpert(uid=uid, host=maybe_result[0], port=maybe_result[1])
+                experts[i] = RemoteExpert(uid=uid, endpoint=maybe_endpoint)
 
         future.set_result(experts)
 
-    def declare_experts(self, uids: List[str], addr, port, wait=True, timeout=None) -> Optional[List[bool]]:
+    def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
-        Make experts available to DHT; update timestamps if already available
+        Make experts visible to all DHT peers; update timestamps if declared previously.
 
         :param uids: a list of expert ids to update
-        :param addr: hostname that can be used to call this expert
-        :param port: port that can be used to call this expert
+        :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
         :param wait: if True, awaits for declaration to finish, otherwise runs in background
         :param timeout: waits for the procedure to finish, None means wait indeninitely
         :returns: if wait, returns a list of booleans, (True = store succeeded, False = store rejected)
         """
-        future, _future = SharedFuture.make_pair() if wait else (None, None)
-        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, future=_future)))
+        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
+        future, _future = MPFuture.make_pair() if wait else (None, None)
+        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), endpoint=endpoint, future=_future)))
         if wait:
             return future.result(timeout)
 
-    def _declare_experts(self, uids: List[str], addr: str, port: int, future: Optional[SharedFuture]):
+    def _declare_experts(self, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
         assert self.node is not None, "This method should only be accessed from inside .run method"
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         loop = asyncio.get_event_loop()
         expiration_time = get_dht_time() + self.EXPIRATION
         unique_prefixes = set()
-        coroutines = []
 
         keys, values = [], []
         for uid in uids:
             uid_parts = uid.split(self.UID_DELIMETER)
             keys.append(self.make_key('expert', uid))
-            values.append((addr, port))
+            values.append(endpoint)
             unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])
 
         for prefix in unique_prefixes:
@@ -171,12 +170,12 @@ class DHT(mp.Process):
         :returns: a list of at most :k: prefixes that have at least one active expert each;
         """
         assert isinstance(prefixes, (list, tuple)), "please provide a list/tuple of prefixes as the first argument"
-        future, _future = SharedFuture.make_pair()
+        future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
                         dict(prefixes=prefixes, k=k, max_prefetch=max_prefetch or k, future=_future)))
         return future.result()
 
-    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: SharedFuture):
+    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: MPFuture):
         assert self.node is not None, "This method should only be accessed from inside .run method"
         max_prefetch = max_prefetch or len(prefixes)
         loop = asyncio.get_event_loop()
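
A sketch of the reworked DHT interface under the new argument order (listen_on first, initial_peers as a sequence) and string endpoints; the expert names and port below are illustrative:

    import hivemind

    first_peer = hivemind.DHT(listen_on='0.0.0.0:*', start=True)
    second_peer = hivemind.DHT(initial_peers=[f'{hivemind.LOCALHOST}:{first_peer.port}'], start=True)

    # declare_experts now takes a single endpoint string instead of separate addr and port
    second_peer.declare_experts(['expert.0', 'expert.1'], endpoint=f'{hivemind.LOCALHOST}:8080')
    expert0, expert1 = first_peer.get_experts(['expert.0', 'expert.1'])
    assert expert0.endpoint == f'{hivemind.LOCALHOST}:8080'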

+ 6 - 6
hivemind/dht/node.py

@@ -314,7 +314,7 @@ class DHTNode:
 
         # search metadata
         unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
-        node_to_addr: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries
+        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries
 
         SearchResult = namedtuple("SearchResult", ["binary_value", "expiration", "source_node_id"])
         latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}
@@ -331,18 +331,18 @@ class DHTNode:
 
         # stage 2: traverse the DHT for any unfinished keys
         for key_id in unfinished_key_ids:
-            node_to_addr.update(self.protocol.routing_table.get_nearest_neighbors(
+            node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
             queries = list(queries)
-            response = await self.protocol.call_find(node_to_addr[peer], queries)
+            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
             output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
             for key_id, (maybe_value, maybe_expiration, peers) in response.items():
-                node_to_addr.update(peers)
+                node_to_endpoint.update(peers)
                 if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
                     latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, peer)
                 should_interrupt = (latest_results[key_id].expiration >= sufficient_expiration_time)
@@ -350,7 +350,7 @@ class DHTNode:
             return output
 
         nearest_nodes_per_query, visited_nodes = await traverse_dht(
-            queries=list(unfinished_key_ids), initial_nodes=list(node_to_addr),
+            queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
             beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
             get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})
 
@@ -367,7 +367,7 @@ class DHTNode:
                     if node_id == latest_node_id:
                         continue
                     asyncio.create_task(self.protocol.call_store(
-                        node_to_addr[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
+                        node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
                     num_cached_nodes += 1
                     if num_cached_nodes >= self.cache_nearest:
                         break

+ 5 - 7
hivemind/dht/protocol.py

@@ -4,7 +4,6 @@ from __future__ import annotations
 import asyncio
 import heapq
 import os
-import urllib.parse
 from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
 from warnings import warn
 
@@ -12,7 +11,7 @@ import grpc
 import grpc.experimental.aio
 
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
-from hivemind.utils import Endpoint, compile_grpc, get_logger
+from hivemind.utils import Endpoint, compile_grpc, get_logger, replace_port, get_port
 
 logger = get_logger(__name__)
 
@@ -40,7 +39,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         See DHTNode (node.py) for a more detailed description.
 
         :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
-         for instance, def rpc_ping can be called as protocol.call_ping(addr, dht_id) from a remote machine
+         for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
          Only the call_* methods are meant to be called publicly, e.g. from DHTNode
          Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
         """
@@ -109,9 +108,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """ Some node wants us to add it to our routing table. """
         if peer_info.node_id and peer_info.rpc_port:
             sender_id = DHTID.from_bytes(peer_info.node_id)
-            peer_url = urllib.parse.urlparse(context.peer())
-            address = peer_url.path[:peer_url.path.rindex(':')]
-            asyncio.create_task(self.update_routing_table(sender_id, f"{address}:{peer_info.rpc_port}"))
+            rpc_endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
+            asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
         return self.node_info
 
     async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
@@ -193,7 +191,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 output[key] = (value, expiration, nearest)
             return output
         except grpc.experimental.aio.AioRpcError as error:
-            logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 
     async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:

+ 10 - 10
hivemind/dht/routing.py

@@ -47,7 +47,7 @@ class RoutingTable:
 
     def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
         """
-        Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
+        Update routing table after an incoming request from :endpoint: or outgoing request to :endpoint:
 
         :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
         :note: DHTProtocol calls this method for every incoming and outgoing request if there was a response.
@@ -91,7 +91,7 @@ class RoutingTable:
         """ Find endpoint for a given DHTID or vice versa """
         return self.uid_to_endpoint[item] if isinstance(item, DHTID) else self.endpoint_to_uid[item]
 
-    def __setitem__(self, node_id: DHTID, addr: Endpoint) -> NotImplementedError:
+    def __setitem__(self, node_id: DHTID, endpoint: Endpoint) -> NotImplementedError:
         raise NotImplementedError("RoutingTable doesn't support direct item assignment. Use table.try_add_node instead")
 
     def __contains__(self, item: Union[DHTID, Endpoint]) -> bool:
@@ -160,7 +160,7 @@ class RoutingTable:
 class KBucket:
     """
     A bucket containing up to :size: of DHTIDs in [lower, upper) semi-interval.
-    Maps DHT node ids to their endpoints (hostname, addr)
+    Maps DHT node ids to their endpoints
     """
 
     def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
@@ -175,13 +175,13 @@ class KBucket:
         """ Check if node_id is between this bucket's lower and upper bounds """
         return self.lower <= node_id < self.upper
 
-    def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> bool:
+    def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> bool:
         """
         Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
         If the bucket is full, keep track of node in a replacement list, per section 4.1 of the paper.
 
         :param node_id: dht node identifier that should be added or moved to the front of bucket
-        :param addr: a pair of (hostname, port) associated with that node id
+        :param endpoint: network address associated with that node id
         :note: this function has a side-effect of resetting KBucket.last_updated time
         """
         if node_id in self.nodes_requested_for_ping:
@@ -189,13 +189,13 @@ class KBucket:
         self.last_updated = get_dht_time()
         if node_id in self.nodes_to_endpoint:
             del self.nodes_to_endpoint[node_id]
-            self.nodes_to_endpoint[node_id] = addr
+            self.nodes_to_endpoint[node_id] = endpoint
         elif len(self.nodes_to_endpoint) < self.size:
-            self.nodes_to_endpoint[node_id] = addr
+            self.nodes_to_endpoint[node_id] = endpoint
         else:
             if node_id in self.replacement_nodes:
                 del self.replacement_nodes[node_id]
-            self.replacement_nodes[node_id] = addr
+            self.replacement_nodes[node_id] = endpoint
             return False
         return True
 
@@ -229,9 +229,9 @@ class KBucket:
         assert self.lower < midpoint < self.upper, f"Bucket to small to be split: [{self.lower}: {self.upper})"
         left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
         right = KBucket(midpoint, self.upper, self.size, depth=self.depth + 1)
-        for node_id, addr in chain(self.nodes_to_endpoint.items(), self.replacement_nodes.items()):
+        for node_id, endpoint in chain(self.nodes_to_endpoint.items(), self.replacement_nodes.items()):
             bucket = left if int(node_id) <= midpoint else right
-            bucket.add_or_update_node(node_id, addr)
+            bucket.add_or_update_node(node_id, endpoint)
         return left, right
 
     def __repr__(self):

+ 18 - 12
hivemind/server/__init__.py

@@ -1,12 +1,16 @@
 import multiprocessing as mp
+import multiprocessing.synchronize
 import threading
 from typing import Dict, Optional
 
 from hivemind.dht import DHT
-from hivemind.runtime import Runtime, ExpertBackend
+from hivemind.server.runtime import Runtime
+from hivemind.server.task_pool import Task, TaskPool, TaskPoolBase
+from hivemind.server.expert_backend import ExpertBackend
 from hivemind.server.checkpoint_saver import CheckpointSaver
 from hivemind.server.connection_handler import ConnectionHandler
 from hivemind.server.dht_handler import DHTHandlerThread
+from hivemind.utils import Endpoint, get_port, replace_port, find_open_port
 
 
 class Server(threading.Thread):
@@ -20,11 +24,10 @@ class Server(threading.Thread):
      - follows orders from HivemindController - if it exists
 
     :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
-     but it will still support accessing experts directly with RemoteExpert(uid=UID, host=IPADDR, port=PORT).
+     but it will still support accessing experts directly with RemoteExpert(uid=UID, endpoint="IPADDR:PORT").
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
-    :param addr: server's dht address that determines how it can be accessed. Default is local connections only.
-    :param port: port to which server listens for requests such as expert forward or backward pass.
-    :param conn_handler_processes: maximum number of simultaneous requests. Please note that the default value of 1
+    :param listen_on: server's dht address that determines how it can be accessed. Address and (optional) port
+    :param num_connection_handlers: maximum number of simultaneous requests. Please note that the default value of 1
         if too small for normal functioning, we recommend 4 handlers per expert backend.
     :param update_period: how often will server attempt to publish its state (i.e. experts) to the DHT;
         if dht is None, this parameter is ignored.
@@ -32,13 +35,16 @@ class Server(threading.Thread):
         is ready (see .ready below)
     """
 
-    def __init__(self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
-                 port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False, checkpoint_dir=None, **kwargs):
+    def __init__(
+            self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], listen_on: Endpoint = "0.0.0.0:*",
+            num_connection_handlers: int = 1, update_period: int = 30, start=False, checkpoint_dir=None, **kwargs):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        self.addr, self.port = addr, port
-        self.conn_handlers = [ConnectionHandler(f"{self.addr}:{port}", self.experts)
-                              for _ in range(conn_handler_processes)]
+        if get_port(listen_on) is None:
+            self.listen_on = listen_on = replace_port(listen_on, new_port=find_open_port())
+        self.port = get_port(listen_on)
+
+        self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
         if checkpoint_dir is not None:
             self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
         else:
@@ -57,8 +63,8 @@ class Server(threading.Thread):
             if not self.dht.is_alive():
                 self.dht.run_in_background(await_ready=True)
 
-            dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht,
-                                                  addr=self.addr, port=self.port, update_period=self.update_period)
+            dht_handler_thread = DHTHandlerThread(
+                experts=self.experts, dht=self.dht, endpoint=self.listen_on, update_period=self.update_period)
             dht_handler_thread.start()
         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
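
A sketch of constructing a server with the renamed arguments (expert_backends is assumed to be a prebuilt Dict[str, ExpertBackend] and dht an optional hivemind.DHT; 4 handlers follow the docstring recommendation):

    server = hivemind.Server(dht, expert_backends, listen_on='0.0.0.0:*',
                             num_connection_handlers=4, update_period=30)
    server.start()
    server.ready.wait()  # a free port is substituted for '*' via find_open_port()
    print(f"Serving experts at {server.listen_on}")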

+ 4 - 4
hivemind/server/checkpoint_saver.py

@@ -8,20 +8,20 @@ from typing import Dict
 
 import torch
 
-from hivemind.runtime import ExpertBackend
+from hivemind.server.expert_backend import ExpertBackend
 
 
 class CheckpointSaver(threading.Thread):
-    def __init__(self, expert_backends: Dict[str, ExpertBackend], dir: Path, update_period: int):
+    def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
         super().__init__()
         self.expert_backends = expert_backends
         self.update_period = update_period
-        self.dir = dir
+        self.checkpoint_dir = checkpoint_dir
         self.stop = False
 
     def run(self) -> None:
         while not self.stop:
-            store_experts(self.expert_backends, self.dir)
+            store_experts(self.expert_backends, self.checkpoint_dir)
             time.sleep(self.update_period)
 
 

+ 1 - 1
hivemind/server/connection_handler.py

@@ -8,7 +8,7 @@ import grpc.experimental.aio
 import torch
 import uvloop
 
-from hivemind.runtime.expert_backend import ExpertBackend
+from hivemind.server.expert_backend import ExpertBackend
 from hivemind.utils import get_logger, serialize_torch_tensor, deserialize_torch_tensor, Endpoint, runtime_pb2, runtime_grpc
 
 logger = get_logger(__name__)

+ 5 - 4
hivemind/server/dht_handler.py

@@ -2,13 +2,14 @@ import threading
 import time
 
 from hivemind.dht import DHT
+from hivemind.utils import Endpoint, get_port
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, update_period: int = 5, addr: str = '127.0.0.1', port: int = 8080):
+    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5):
         super(DHTHandlerThread, self).__init__()
-        self.port = port
-        self.addr = addr
+        assert get_port(endpoint) is not None
+        self.endpoint = endpoint
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
@@ -16,5 +17,5 @@ class DHTHandlerThread(threading.Thread):
 
     def run(self) -> None:
         while not self.stop:
-            self.dht.declare_experts(self.experts.keys(), self.addr, self.port)
+            self.dht.declare_experts(self.experts.keys(), self.endpoint)
             time.sleep(self.update_period)

+ 10 - 7
hivemind/runtime/expert_backend.py → hivemind/server/expert_backend.py

@@ -3,8 +3,9 @@ from typing import Dict, Sequence, Any, Tuple, Union
 import torch
 from torch import nn
 
-from hivemind.runtime.task_pool import TaskPool
-from hivemind.utils import nested_flatten, nested_pack, nested_compare, BatchTensorDescriptor, DUMMY_BATCH_SIZE, nested_map
+from hivemind.server.task_pool import TaskPool
+from hivemind.utils import nested_flatten, nested_pack, nested_compare, nested_map,\
+    BatchTensorDescriptor, DUMMY_BATCH_SIZE
 
 
 class ExpertBackend(nn.Module):
@@ -18,7 +19,7 @@ class ExpertBackend(nn.Module):
 
     :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
 
-     - Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of same type
+     - Experts must always receive the same set of args and kwargs and produce output tensors of same type
      - All args, kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
      - We recommend using experts that are ~invariant to the order in which they process batches
      - Using randomness (e.g. Dropout) leads to different samples at forward and backward. If you want consistency,
@@ -95,15 +96,17 @@ class ExpertBackend(nn.Module):
 
            .. todo correct state handling (see forward)
 
-           Please make sure to call ``ExpertBackend.apply_gradients`` **within** this method, otherwise the expert will not train
+           Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
         with torch.enable_grad():
             args = [tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
                     else tensor.detach() for tensor in args]
-            kwargs = {input_key: (tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
-                                  else tensor.detach()) for input_key, tensor in kwargs.items()}
+            kwargs = {input_key: (tensor.detach().requires_grad_(True)
+                                  if tensor.dtype in (torch.half, torch.float, torch.double)
+                                  else tensor.detach())
+                      for input_key, tensor in kwargs.items()}
 
             outputs = self.expert(*args, **kwargs)
             assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
@@ -122,7 +125,7 @@ class ExpertBackend(nn.Module):
 
     def apply_gradients(self) -> None:
         """
-        Train the expert for a single step. This method is called by ``ExpertBackend.backward`` after computing gradients.
+        Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
         """
         self.opt.step()
         self.opt.zero_grad()

+ 2 - 2
hivemind/runtime/__init__.py → hivemind/server/runtime.py

@@ -1,4 +1,5 @@
 import multiprocessing as mp
+import multiprocessing.pool
 import threading
 from itertools import chain
 from selectors import DefaultSelector, EVENT_READ
@@ -7,8 +8,7 @@ from typing import Dict
 import torch
 from prefetch_generator import BackgroundGenerator
 
-from hivemind.runtime.expert_backend import ExpertBackend
-from hivemind.runtime.task_pool import TaskPool, TaskPoolBase
+from hivemind.server.expert_backend import ExpertBackend
 from hivemind.utils import get_logger
 
 logger = get_logger(__name__)

+ 5 - 4
hivemind/runtime/task_pool.py → hivemind/server/task_pool.py

@@ -3,6 +3,7 @@ Task pool is responsible for receiving tasks and grouping them together for proc
 """
 import ctypes
 import multiprocessing as mp
+import multiprocessing.context
 import os
 import threading
 import time
@@ -14,7 +15,7 @@ from typing import List, Tuple, Dict, Any, Generator
 
 import torch
 
-from hivemind.utils import SharedFuture, get_logger
+from hivemind.utils import MPFuture, get_logger
 
 logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
@@ -86,10 +87,10 @@ class TaskPool(TaskPoolBase):
 
     def submit_task(self, *args: torch.Tensor) -> Future:
         """ Add task to this pool's queue, return Future for its output """
-        future1, future2 = SharedFuture.make_pair()
+        future1, future2 = MPFuture.make_pair()
         task = Task(future1, args)
         if self.get_task_size(task) > self.max_batch_size:
-            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it will never be finished")
+            exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             future2.set_exception(exc)
         else:
             self.tasks.put(task)
@@ -127,7 +128,7 @@ class TaskPool(TaskPoolBase):
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
         logger.info(f'{self.uid} starting, pid={os.getpid()}')
-        pending_batches = {}  # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
+        pending_batches = {}  # Dict[batch uuid, List[MPFuture]] for each batch currently in runtime
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
                                          name=f'{self.uid}_output')
         try:
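
Given an already-running TaskPool instance (its construction is not shown in this hunk; the tensor shape is illustrative), submit_task returns an MPFuture-backed future for the batch output:

    batch_future = pool.submit_task(torch.randn(4, 1024))  # rejected with ValueError if larger than max_batch_size
    outputs = batch_future.result()  # resolves once the runtime has processed the batch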

+ 2 - 3
hivemind/utils/__init__.py

@@ -1,9 +1,8 @@
-from hivemind.utils.connection import *
-from hivemind.utils.data import *
+from hivemind.utils.networking import *
 from hivemind.utils.nested import *
 from hivemind.utils.tensor_descr import *
 from hivemind.utils.serializer import *
-from hivemind.utils.shared_future import *
+from hivemind.utils.mpfuture import *
 from hivemind.utils.threading import *
 from hivemind.utils.autograd import *
 from hivemind.utils.grpc import *

+ 2 - 1
hivemind/utils/autograd.py

@@ -92,7 +92,8 @@ class _ParallelApplyFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grad_outputs_flat: torch.Tensor):
         func, contexts, output_strides = ctx._inner_func, ctx._call_contexts, ctx._output_strides
-        grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]] for i in range(len(contexts))]
+        grad_outputs_per_call = [grad_outputs_flat[output_strides[i]: output_strides[i + 1]]
+                                 for i in range(len(contexts))]
         futures = [run_in_background(run_isolated_backward, func, context, *grads)
                    for context, grads in zip(contexts, grad_outputs_per_call)]
         flat_grads_wrt_input = tuple(grad for future in futures for grad in future.result())

+ 0 - 3
hivemind/utils/data.py

@@ -1,3 +0,0 @@
-import torch
-
-DUMMY = torch.empty(0, requires_grad=True)

+ 2 - 1
hivemind/utils/grpc.py

@@ -47,7 +47,8 @@ def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]:
                 raise ImportError("Something changed sys.path while compile_grpc was in progress.")
 
 
-with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'server', 'connection_handler.proto')) as f_proto:
+with open(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+                       'server', 'connection_handler.proto')) as f_proto:
     runtime_pb2, runtime_grpc = compile_grpc(f_proto.read())
 
 

+ 2 - 2
hivemind/utils/logging.py

@@ -8,8 +8,8 @@ def get_logger(module_name: str) -> logging.Logger:
     loglevel = os.getenv('LOGLEVEL', 'INFO')
 
     logging.addLevelName(logging.WARNING, 'WARN')
-    formatter = logging.Formatter(fmt='[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}', style='{',
-                                  datefmt='%Y/%m/%d %H:%M:%S')
+    formatter = logging.Formatter(fmt='[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}',
+                                  style='{', datefmt='%Y/%m/%d %H:%M:%S')
     handler = logging.StreamHandler()
     handler.setFormatter(formatter)
     logger = logging.getLogger(name_without_prefix)

+ 1 - 1
hivemind/utils/shared_future.py → hivemind/utils/mpfuture.py

@@ -5,7 +5,7 @@ from warnings import warn
 import asyncio
 
 
-class SharedFuture(Future):
+class MPFuture(Future):
     """ Multiprocessing version of concurrent.futures.Future, interacts between two processes via Pipe """
     STATES = 'pending', 'running', 'cancelled', 'finished', 'exception'
     STATE_PENDING, STATE_RUNNING, STATE_CANCELLED, STATE_FINISHED, STATE_EXCEPTION = STATES
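
A sketch of the renamed class in use (the worker function and value are illustrative): MPFuture.make_pair() creates two linked futures, and resolving one side from another process resolves the other through the underlying Pipe, mirroring how DHT hands the second future to its background process:

    import multiprocessing as mp
    from hivemind.utils import MPFuture

    def _worker(side_b: MPFuture):
        side_b.set_result(42)  # propagates to side_a in the parent process

    side_a, side_b = MPFuture.make_pair()
    mp.Process(target=_worker, args=(side_b,), daemon=True).start()
    assert side_a.result() == 42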

+ 16 - 0
hivemind/utils/connection.py → hivemind/utils/networking.py

@@ -1,11 +1,27 @@
 import socket
+import urllib.parse
 from contextlib import closing
+from typing import Optional
 
 Hostname, Port = str, int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = '127.0.0.1'
 
 
+def get_port(endpoint: Endpoint) -> Optional[Port]:
+    """ get port or None if port is undefined """
+    # TODO: find a standard way to get port, make sure it works in malformed ports
+    try:
+        return int(endpoint[endpoint.rindex(':') + 1:], base=10)
+    except ValueError:  # :* or not specified
+        return None
+
+
+def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
+    assert endpoint.endswith(':*') or get_port(endpoint) is not None, endpoint
+    return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
+
+
 def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
     try:
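
The new endpoint helpers can be exercised directly; a small illustration (values chosen arbitrarily):

    from hivemind.utils import get_port, replace_port, find_open_port

    assert get_port('127.0.0.1:8080') == 8080
    assert get_port('0.0.0.0:*') is None                    # wildcard / unspecified port
    assert replace_port('0.0.0.0:*', 1337) == '0.0.0.0:1337'
    endpoint = replace_port('0.0.0.0:*', find_open_port())  # bind-ready endpoint with a free TCP port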

+ 0 - 17
hivemind/utils/serializer.py

@@ -2,11 +2,9 @@
 import pickle
 from io import BytesIO
 
-import joblib
 import torch
 import umsgpack
 
-
 class SerializerBase:
     @staticmethod
     def dumps(obj: object) -> bytes:
@@ -17,19 +15,6 @@ class SerializerBase:
         raise NotImplementedError()
 
 
-class JoblibSerializer(SerializerBase):
-
-    @staticmethod
-    def dumps(obj: object) -> bytes:
-        s = BytesIO()
-        joblib.dump(obj, s)
-        return s.getvalue()
-
-    @staticmethod
-    def loads(buf: bytes) -> object:
-        return joblib.load(BytesIO(buf))
-
-
 class PickleSerializer(SerializerBase):
     @staticmethod
     def dumps(obj: object) -> bytes:
@@ -41,7 +26,6 @@ class PickleSerializer(SerializerBase):
 
 
 class PytorchSerializer(SerializerBase):
-
     @staticmethod
     def dumps(obj: object) -> bytes:
         s = BytesIO()
@@ -54,7 +38,6 @@ class PytorchSerializer(SerializerBase):
 
 
 class MSGPackSerializer(SerializerBase):
-
     @staticmethod
     def dumps(obj: object) -> bytes:
         return umsgpack.dumps(obj, use_bin_type=False)  # TODO strict https://github.com/msgpack/msgpack-python/pull/158

+ 2 - 1
hivemind/utils/tensor_descr.py

@@ -45,7 +45,8 @@ class BatchTensorDescriptor(TensorDescriptor):
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor):
         return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
-                   device=tensor.device, requires_grad=tensor.requires_grad, pin_memory=torch.cuda.is_available() and tensor.is_pinned())
+                   device=tensor.device, requires_grad=tensor.requires_grad,
+                   pin_memory=torch.cuda.is_available() and tensor.is_pinned())
 
     def make_empty(self, batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"

+ 0 - 1
requirements.txt

@@ -1,5 +1,4 @@
 torch>=1.3.0
-joblib>=0.13
 numpy>=1.17
 prefetch_generator>=1.0.1
 umsgpack

+ 6 - 6
tests/benchmark_dht.py

@@ -9,9 +9,9 @@ from tqdm import trange
 from test_utils import increase_file_limit
 
 
-def random_endpoint() -> Tuple[str, int]:
-    return (f"{random.randint(0, 256)}.{random.randint(0, 256)}."
-            f"{random.randint(0, 256)}.{random.randint(0, 256)}", random.randint(0, 65535))
+def random_endpoint() -> hivemind.Endpoint:
+    return f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}." \
+           f"{random.randint(0, 256)}:{random.randint(0, 65535)}"
 
 
 def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_batch_size: int, random_seed: int,
@@ -23,7 +23,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     peers = []
     for _ in trange(num_peers):
         neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
-        peer = hivemind.DHT(*neighbors, start=True, wait_timeout=wait_timeout, listen_on=f'0.0.0.0:*')
+        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout, listen_on=f'0.0.0.0:*')
         peers.append(peer)
 
     store_peer, get_peer = peers[-2:]
@@ -41,7 +41,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
-        success_list = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], *endpoints[-1])
+        success_list = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], endpoints[-1])
         total_store_time += time.perf_counter() - store_start
 
         total_stores += len(success_list)
@@ -64,7 +64,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
         for i, expert in enumerate(get_result):
             if expert is not None and expert.uid == expert_uids[start + i] \
-                    and (expert.host, expert.port) == endpoints[start // expert_batch_size]:
+                    and expert.endpoint == endpoints[start // expert_batch_size]:
                 successful_gets += 1
 
     if time.perf_counter() - benchmark_started > expiration_time:

+ 3 - 2
tests/benchmark_throughput.py

@@ -14,7 +14,7 @@ from hivemind import find_open_port
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
     torch.set_num_threads(1)
     can_start.wait()
-    experts = [hivemind.RemoteExpert(f"expert{i}", port=port) for i in range(num_experts)]
+    experts = [hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)]
 
     try:
         dummy_batch = torch.randn(batch_size, hid_dim)
@@ -69,7 +69,8 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
                                                            max_batch_size=max_batch_size,
                                                            )
         timestamps['created_experts'] = time.perf_counter()
-        server = hivemind.Server(None, experts, port=port, conn_handler_processes=num_handlers, device=device)
+        server = hivemind.Server(None, experts, listen_on=f"{hivemind.LOCALHOST}:{port}",
+                                 num_connection_handlers=num_handlers, device=device)
         server.start()
         server.ready.wait()
         timestamps['server_ready'] = time.perf_counter()

+ 15 - 12
tests/test_dht.py

@@ -5,7 +5,7 @@ import random
 import heapq
 import uuid
 from itertools import chain
-from typing import Optional
+from typing import Optional, Tuple
 import numpy as np
 
 import hivemind
@@ -68,11 +68,12 @@ def test_dht_protocol():
                 protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
             recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
             (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-            assert recv_id == peer2_id and recv_endpoint == f"{LOCALHOST}:{peer2_port}", \
+            assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
                 f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
 
-            assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-                                                                          f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+            assert recv_value == value and recv_expiration == expiration, \
+                f"call_find_value expected {value} (expires by {expiration}) " \
+                f"but got {recv_value} (expires by {recv_expiration})"
 
             # peer 2 must know about peer 1, but not have a *random* nonexistent value
             dummy_key = DHTID.generate()
@@ -89,7 +90,7 @@ def test_dht_protocol():
 
             if listen:
                 loop.run_until_complete(protocol.shutdown())
-            print("DHTProtocol test finished sucessfully!")
+            print("DHTProtocol test finished successfully!")
             test_success.set()
 
     tester = mp.Process(target=_tester, daemon=True)
@@ -178,13 +179,15 @@ def test_dht_node():
 
         # test 1: find self
         nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-        assert len(nearest) == 1 and nearest[me.node_id] == f"{LOCALHOST}:{me.port}"
+        assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
 
         # test 2: find others
         for i in range(10):
             ref_endpoint, query_id = random.choice(list(dht.items()))
             nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
-            assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)
+            assert len(nearest) == 1
+            found_node_id, found_endpoint = next(iter(nearest.items()))
+            assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
 
         # test 3: find neighbors to random nodes
         accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
@@ -266,7 +269,7 @@ def test_hivemind_dht():
     peers = [hivemind.DHT(start=True)]
     for i in range(10):
         neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(hivemind.DHT(*neighbors_i, start=True))
+        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))
 
     you: hivemind.dht.DHT = random.choice(peers)
     theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
@@ -281,10 +284,10 @@ def test_hivemind_dht():
     assert all(res is None for res in found[-2:]), "Found non-existing experts"
 
     that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
-    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], 'that_host', that_guys_port)
+    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
     you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
     assert isinstance(you_found, hivemind.RemoteExpert)
-    assert you_found.host == 'that_host', you_found.port == that_guys_port
+    assert you_found.endpoint == f'that_host:{that_guys_port}'
 
     # test first_k_active
     assert theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10) == expert_uids[:10]
@@ -302,9 +305,9 @@ def test_hivemind_dht():
 
 def test_dht_single_node():
     node = hivemind.DHT(start=True)
-    assert all(node.declare_experts(['e1', 'e2', 'e3'], hivemind.LOCALHOST, 1337))
+    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:{1337}"))
     for expert in node.get_experts(['e3', 'e2']):
-        assert expert.host == hivemind.LOCALHOST and expert.port == 1337
+        assert expert.endpoint == f"{hivemind.LOCALHOST}:{1337}"
     assert node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2) == ['e1', 'e3']
 
 

+ 6 - 6
tests/test_moe.py

@@ -19,8 +19,8 @@ def test_remote_module_call():
     random_proj = torch.randn_like(xx)
 
     with background_server(num_experts=num_experts, device='cpu', num_handlers=1,
-                           no_optimizer=True, no_dht=True) as (localhost, server_port, dht_port):
-        experts = [hivemind.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
+                           no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
+        experts = [hivemind.RemoteExpert(uid=f'expert.{i}', endpoint=server_endpoint) for i in range(num_experts)]
         moe_output, = hivemind.client.moe._RemoteMoECall.apply(
             logits, experts[:len(logits)], k_min, timeout_after_k_min, backward_k_min, timeout_total, backward_timeout,
             [(None,), {}], xx)
@@ -51,8 +51,8 @@ def test_determinism():
     mask = torch.randint(0, 1, (32, 1024))
 
     with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1,
-                           no_optimizer=True, no_dht=True) as (interface, server_port, dht_port):
-        expert = hivemind.RemoteExpert(uid=f'expert.0', port=server_port)
+                           no_optimizer=True, no_dht=True) as (server_endpoint, dht_endpoint):
+        expert = hivemind.RemoteExpert(uid=f'expert.0', endpoint=server_endpoint)
 
         out = expert(xx, mask)
         out_rerun = expert(xx, mask)
@@ -70,11 +70,11 @@ def test_compute_expert_scores():
         moe = hivemind.client.moe.RemoteMixtureOfExperts(
             dht=dht, in_features=1024, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
             uid_prefix='expert')
-        gx, gy = torch.randn(4, 5, requires_grad=True), torch.torch.randn(4, 3, requires_grad=True)
+        gx, gy = torch.randn(4, 5, requires_grad=True), torch.randn(4, 3, requires_grad=True)
         ii = [[4, 0, 2], [3, 1, 1, 1, 3], [0], [3, 2]]
         jj = [[2, 2, 1], [0, 1, 2, 0, 1], [0], [1, 2]]
         batch_experts = [
-            [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}')
+            [hivemind.RemoteExpert(uid=f'expert.{ii[batch_i][expert_i]}.{jj[batch_i][expert_i]}', endpoint="[::]:1337")
              for expert_i in range(len(ii[batch_i]))]
             for batch_i in range(len(ii))
         ]  # note: these experts do not exists on server, we use them only to test moe compute_expert_scores

+ 4 - 3
tests/test_routing.py

@@ -115,16 +115,17 @@ def test_routing_table_search():
             k = random.randint(1, 100)
             query_id = DHTID.generate()
             exclude = query_id if random.random() < 0.5 else None
-            our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
+            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
             reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
             assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
-            assert all(our_addr == routing_table[our_node] for our_node, our_addr in zip(our_knn, our_addrs))
+            assert all(our_endpoint == routing_table[our_node]
+                       for our_node, our_endpoint in zip(our_knn, our_endpoints))
 
         # queries from table
         for i in range(1000):
             k = random.randint(1, 100)
             query_id = random.choice(all_active_neighbors)
-            our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
+            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
 
             reference_knn = heapq.nsmallest(k + 1, all_active_neighbors, key=query_id.xor_distance)
             if query_id in reference_knn:

+ 4 - 7
tests/test_training.py

@@ -1,4 +1,3 @@
-#%env CUDA_VISIBLE_DEVICES=
 import argparse
 from typing import Optional
 
@@ -6,21 +5,19 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from hivemind import RemoteExpert, find_open_port
+from hivemind import RemoteExpert, find_open_port, LOCALHOST
 from test_utils.run_server import background_server
 
 from sklearn.datasets import load_digits
 
 
 def test_training(port: Optional[int] = None, max_steps: int = 100, threshold: float = 0.9):
-    if port is None:
-        port = find_open_port()
     dataset = load_digits()
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
 
-    with background_server(num_experts=2, device='cpu', port=port, hidden_dim=64):
-        expert1 = RemoteExpert('expert.0', host='127.0.0.1', port=port)
-        expert2 = RemoteExpert('expert.1', host='127.0.0.1', port=port)
+    with background_server(num_experts=2, device='cpu', hidden_dim=64) as (server_endpoint, _):
+        expert1 = RemoteExpert('expert.0', server_endpoint)
+        expert2 = RemoteExpert('expert.1', server_endpoint)
         model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))
 
         opt = torch.optim.SGD(model.parameters(), lr=0.05)

+ 19 - 13
tests/test_utils/run_server.py

@@ -3,20 +3,21 @@ import multiprocessing as mp
 from contextlib import contextmanager
 
 import resource
+from typing import Tuple
+
 import torch
 
 import hivemind
 from test_utils.layers import name_to_block, name_to_input
 
 
-def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024,
+def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hidden_dim=1024,
                       num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                       no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
                       UID_DELIMETER=hivemind.DHT.UID_DELIMETER, start=False, **kwargs) -> hivemind.Server:
     """
     Instantiate a server with several identical experts. See argparse comments below for details
-    :param interface: 'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6
-    :param port: main server will listen to this port, default = find open port
+    :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
     :param num_experts: run this many identical experts
     :param expert_cls: expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop';
     :param hidden_dim: main dimension for expert_cls
@@ -46,8 +47,9 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
     if not no_dht:
         if not len(initial_peers):
             print("No initial peers provided. Starting additional dht as an initial peer.")
-            dht_root = hivemind.DHT(
-                *initial_peers, listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}", start=True)
+            dht_root = hivemind.DHT(initial_peers=initial_peers,
+                                    listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}",
+                                    start=True)
             print(f"Initializing DHT with port {dht_root.port}")
             initial_peers = [f"{hivemind.LOCALHOST}:{dht_root.port}"]
         else:
@@ -55,8 +57,9 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
             if root_port is not None:
                 print(f"Warning: root_port={root_port} will not be used since we already have peers.")
 
-        dht = hivemind.DHT(
-            *initial_peers, listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}", start=True)
+        dht = hivemind.DHT(initial_peers=initial_peers,
+                           listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}",
+                           start=True)
         if verbose:
             print(f"Running dht node on port {dht.port}")
 
@@ -79,19 +82,19 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
                                                      )
     # actually start server
     server = hivemind.Server(
-        dht, experts, addr=interface, port=port or hivemind.find_open_port(),
-        conn_handler_processes=num_handlers, device=device)
+        dht, experts, listen_on=listen_on,
+        num_connection_handlers=num_handlers, device=device)
 
     if start:
         server.run_in_background(await_ready=True)
         if verbose:
-            print(f"Server started at {server.addr}:{server.port}")
+            print(f"Server started at {server.listen_on}")
             print(f"Got {num_experts} active experts of type {expert_cls}: {list(experts.keys())}")
     return server
 
 
 @contextmanager
-def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs):
+def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs) -> Tuple[hivemind.Endpoint, hivemind.Endpoint]:
     """ A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit """
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.get_context("spawn").Process(
@@ -115,8 +118,11 @@ def background_server(*args, shutdown_timeout=5, verbose=True, **kwargs):
 def _server_runner(pipe, *args, verbose, **kwargs):
     server = make_dummy_server(*args, verbose=verbose, start=True, **kwargs)
     try:
-        dht_port = server.dht.port if server.dht is not None else None
-        pipe.send((server.addr, server.port, dht_port))
+        if server.dht is not None:
+            dht_listen_on = hivemind.replace_port(server.dht.listen_on, server.dht.port)
+        else:
+            dht_listen_on = None
+        pipe.send((server.listen_on, dht_listen_on))
         pipe.recv()  # wait for shutdown signal
     finally:
         if verbose: