Răsfoiți Sursa

add benchmarks

Denis Mazur 4 ani în urmă
părinte
comite
87012861ff

+ 34 - 19
benchmarks/benchmark_throughput-p2p.py → benchmarks/benchmark_throughput_p2p.py

@@ -7,7 +7,7 @@ import time
 import torch
 
 import hivemind
-from hivemind import get_free_port
+from hivemind.dht import DHT
 from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
@@ -28,11 +28,20 @@ def print_device_info(device=None):
         logger.info(f"Cached:   {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
 
 
-def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+def client_process(
+    can_start,
+    benchmarking_failed,
+    server_peer_info,
+    num_experts,
+    batch_size,
+    hid_dim,
+    num_batches,
+    backprop=True,
+) -> None:
     torch.set_num_threads(1)
     can_start.wait()
     experts = [
-        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
+        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info) for i in range(num_experts)
     ]
 
     try:
@@ -48,17 +57,16 @@ def client_process(can_start, benchmarking_failed, port, num_experts, batch_size
 
 
 def benchmark_throughput(
-    num_experts=16,
+    num_experts=1,
     num_handlers=None,
-    num_clients=128,
+    num_clients=1,
     num_batches_per_client=16,
     expert_cls="ffn",
     hid_dim=1024,
-    batch_size=2048,
+    batch_size=16,
     max_batch_size=None,
     backprop=True,
     device=None,
-    port=None,
 ):
     assert (
         not hasattr(torch.cuda, "is_initialized")
@@ -66,7 +74,6 @@ def benchmark_throughput(
         or torch.device(device) == torch.device("cpu")
     )
     assert expert_cls in layers.name_to_block
-    port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()
@@ -74,8 +81,12 @@ def benchmark_throughput(
     timestamps = dict(started=time.perf_counter())
 
     try:
-        # start clients and await server
-        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
+        server_dht = DHT(start=True)
+        server_dht_peer_info = hivemind.PeerInfo(
+            peer_id=server_dht.peer_id,
+            addrs=[addr.decapsulate("/p2p/" + addr.get("p2p")) for addr in server_dht.get_visible_maddrs()],
+        )
+
         clients = [
             mp.Process(
                 target=client_process,
@@ -83,30 +94,29 @@ def benchmark_throughput(
                 args=(
                     can_start,
                     benchmarking_failed,
-                    port,
+                    server_dht_peer_info,
                     num_experts,
                     batch_size,
                     hid_dim,
                     num_batches_per_client,
                     backprop,
                 ),
+                daemon=True,
             )
             for i in range(num_clients)
         ]
 
         for client in clients:
-            client.daemon = True
             client.start()
 
         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
 
-        # start server
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         experts = {}
         for i in range(num_experts):
             expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
-            experts[f"expert{i}"] = hivemind.ExpertBackend(
-                name=f"expert{i}",
+            experts[f"expert.{i}"] = hivemind.ExpertBackend(
+                name=f"expert.{i}",
                 expert=expert,
                 optimizer=torch.optim.Adam(expert.parameters()),
                 args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
@@ -114,11 +124,11 @@ def benchmark_throughput(
                 max_batch_size=max_batch_size,
             )
         timestamps["created_experts"] = time.perf_counter()
+
         server = hivemind.moe.Server(
-            None,
-            experts,
-            listen_on=f"{hivemind.LOCALHOST}:{port}",
-            num_connection_handlers=num_handlers,
+            dht=server_dht,
+            expert_backends=experts,
+            num_connection_handlers=1,  # TODO: support greater number
             device=device,
         )
         server.start()
@@ -127,8 +137,13 @@ def benchmark_throughput(
         can_start.set()
 
         for client in clients:
+            print("joining clients")
+            print(client.is_alive())
             client.join()
+            print("i'm here")
+
         timestamps["clients_finished"] = time.perf_counter()
+
     except BaseException as e:
         benchmarking_failed.set()
         raise e

+ 1 - 2
hivemind/hivemind_cli/run_server.py

@@ -17,8 +17,7 @@ def main():
     # fmt:off
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
-    parser.add_argument('--listen_on', type=str, default='0.0.0.0:*', required=False,
-                        help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
+
     parser.add_argument('--num_experts', type=int, default=None, required=False, help="The number of experts to serve")
     parser.add_argument('--expert_pattern', type=str, default=None, required=False,
                         help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will'

+ 8 - 1
hivemind/moe/__init__.py

@@ -1,2 +1,9 @@
 from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts
-from hivemind.moe.server import ExpertBackend, Server, declare_experts, get_experts, register_expert_class, ConnectionHandler
+from hivemind.moe.server import (
+    ConnectionHandler,
+    ExpertBackend,
+    Server,
+    declare_experts,
+    get_experts,
+    register_expert_class,
+)

+ 15 - 22
hivemind/moe/client/expert.py

@@ -1,32 +1,23 @@
-import asyncio
-from concurrent.futures import Future
-import multiprocessing as mp
 import pickle
-from typing import Any, Dict, Optional, Tuple, Awaitable
-from threading import Thread
+from concurrent.futures import Future
 from queue import Queue
+from threading import Thread
+from typing import Any, Awaitable, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-
-#from hivemind.moe.server.connection_handler import ConnectionHandler
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+import hivemind
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.proto import runtime_pb2
-
+from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
 
-import hivemind
-
-#ConnectionHandlerStub = hivemind.moe.server.connection_handler.ConnectionHandler._stub_type
-
-
-def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo): # -> ConnectionHandlerStub:
+def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
     return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
 
 
@@ -39,7 +30,7 @@ class RemoteExpert(nn.Module):
     :param uid: unique expert identifier
     """
 
-    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None):
+    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None, connect: bool = True):
         super().__init__()
         self.uid, self.server_peer_info = uid, server_peer_info
         self._info = None
@@ -49,7 +40,8 @@ class RemoteExpert(nn.Module):
         else:
             self.p2p = p2p
 
-        _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
+        if connect:
+            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
 
     @property
     def stub(self) -> StubBase:
@@ -66,6 +58,7 @@ class RemoteExpert(nn.Module):
 
         if not nested_compare(forward_inputs, self.info["forward_schema"]):
             raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
+
         flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
 
         # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
@@ -100,7 +93,6 @@ class _RemoteModuleCall(torch.autograd.Function):
                 except Exception as e:
                     future.set_exception(e)
                     continue
-
                 future.set_result(result)
 
         loop.run_until_complete(receive_tasks())
@@ -108,7 +100,7 @@ class _RemoteModuleCall(torch.autograd.Function):
     @classmethod
     def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
         if cls._event_thread is None:
-            cls._event_thread = Thread(target=cls._run)
+            cls._event_thread = Thread(target=cls._run, daemon=True)
             cls._event_thread.start()
 
         future = Future()
@@ -117,7 +109,8 @@ class _RemoteModuleCall(torch.autograd.Function):
         if return_future:
             return future
 
-        return future.result()
+        result = future.result()
+        return result
 
     @classmethod
     def forward(
@@ -125,7 +118,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx,
         dummy: torch.Tensor,
         uid: str,
-        stub,#: ConnectionHandlerStub,
+        stub,  #: ConnectionHandlerStub,
         info: Dict[str, Any],
         *inputs: torch.Tensor,
     ) -> Tuple[torch.Tensor, ...]:

+ 2 - 9
hivemind/moe/server/__init__.py

@@ -58,7 +58,6 @@ class Server(threading.Thread):
         self,
         dht: Optional[DHT],
         expert_backends: Dict[str, ExpertBackend],
-        listen_on: Endpoint = "0.0.0.0:*",
         num_connection_handlers: int = 1,
         update_period: int = 30,
         start=False,
@@ -67,9 +66,6 @@ class Server(threading.Thread):
     ):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
-        if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=get_free_port())
-        self.listen_on, self.port = listen_on, get_port(listen_on)
 
         self.conn_handlers = [ConnectionHandler(dht, self.experts) for _ in range(1)]
         if checkpoint_dir is not None:
@@ -82,7 +78,7 @@ class Server(threading.Thread):
             self.dht_handler_thread = DHTHandlerThread(
                 experts=self.experts,
                 dht=self.dht,
-                endpoint=self.listen_on,
+                peer_id=self.dht.peer_id,
                 update_period=self.update_period,
                 daemon=True,
             )
@@ -93,7 +89,6 @@ class Server(threading.Thread):
     @classmethod
     def create(
         cls,
-        listen_on="0.0.0.0:*",
         num_experts: int = None,
         expert_uids: str = None,
         expert_pattern: str = None,
@@ -222,7 +217,6 @@ class Server(threading.Thread):
         return cls(
             dht,
             experts,
-            listen_on=listen_on,
             num_connection_handlers=num_handlers,
             device=device,
             checkpoint_dir=checkpoint_dir,
@@ -235,8 +229,7 @@ class Server(threading.Thread):
         Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        logger.info(f"Server started at {self.listen_on}")
-        logger.info(f"Got {len(self.experts)} experts:")
+        logger.info(f"Server started with {len(self.experts)} experts:")
         for expert_name, backend in self.experts.items():
             num_parameters = sum(p.numel() for p in backend.expert.parameters() if p.requires_grad)
             logger.info(f"{expert_name}: {backend.expert.__class__.__name__}, {num_parameters} parameters")

+ 6 - 3
hivemind/moe/server/connection_handler.py

@@ -5,11 +5,11 @@ from typing import Dict
 
 import torch
 
-from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
+from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import get_logger, nested_flatten, MPFuture
+from hivemind.utils import MPFuture, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
@@ -45,6 +45,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             except Exception as e:
                 self.ready.set_exception(e)
                 return
+
         self.ready.set_result(None)
 
         try:
@@ -65,7 +66,9 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
 
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
-    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+    async def rpc_backward(
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         serialized_response = [

+ 13 - 13
hivemind/moe/server/dht_handler.py

@@ -14,27 +14,27 @@ from hivemind.moe.server.expert_uid import (
     is_valid_uid,
     split_uid,
 )
-from hivemind.utils import Endpoint, get_dht_time, get_port
+from hivemind.p2p import PeerID, PeerInfo
+from hivemind.utils import get_dht_time
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, endpoint: Endpoint, update_period: int = 5, **kwargs):
+    def __init__(self, experts, dht: DHT, peer_id: PeerID, update_period: int = 5, **kwargs):
         super().__init__(**kwargs)
-        assert get_port(endpoint) is not None
-        self.endpoint = endpoint
+        self.peer_id = peer_id
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), self.endpoint)
+        declare_experts(self.dht, self.experts.keys(), self.peer_id)
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), self.endpoint)
+            declare_experts(self.dht, self.experts.keys(), self.peer_id)
 
 
 def declare_experts(
-    dht: DHT, uids: Sequence[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration = 300, wait: bool = True
+    dht: DHT, uids: Sequence[ExpertUID], peer_id: PeerID, expiration: DHTExpiration = 300, wait: bool = True
 ) -> Dict[ExpertUID, bool]:
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
@@ -49,22 +49,22 @@ def declare_experts(
     for uid in uids:
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
     return dht.run_coroutine(
-        partial(_declare_experts, uids=list(uids), endpoint=endpoint, expiration=expiration), return_future=not wait
+        partial(_declare_experts, uids=list(uids), peer_id=peer_id, expiration=expiration), return_future=not wait
     )
 
 
 async def _declare_experts(
-    dht: DHT, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint, expiration: DHTExpiration
+    dht: DHT, node: DHTNode, uids: List[ExpertUID], peer_id: PeerID, expiration: DHTExpiration
 ) -> Dict[ExpertUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     expiration_time = get_dht_time() + expiration
     data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
     for uid in uids:
-        data_to_store[uid, None] = endpoint
+        data_to_store[uid, None] = peer_id.to_base58()
         prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}"
         for i in range(prefix.count(UID_DELIMITER) - 1):
             prefix, last_coord = split_uid(prefix)
-            data_to_store[prefix, last_coord] = [uid, endpoint]
+            data_to_store[prefix, last_coord] = [uid, peer_id.to_base58()]
 
     keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
     store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
@@ -94,6 +94,6 @@ async def _get_experts(
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     for i, uid in enumerate(uids):
-        if found[uid] is not None and isinstance(found[uid].value, Endpoint):
-            experts[i] = RemoteExpert(uid, found[uid].value)
+        if found[uid] is not None and isinstance(found[uid].value, PeerID):
+            experts[i] = RemoteExpert(uid, PeerInfo(peer_id=found[uid].value, addrs=[]))
     return experts