
review issues fix

Pavel Samygin 3 years ago
parent
commit
157c14422a

+ 2 - 2
hivemind/moe/client/beam_search.py

@@ -151,7 +151,7 @@ class MoEBeamSearcher:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                     successors = {
-                        coord: UidEndpoint(match.value[0], PeerInfo.from_tuple(match.value[1]))
+                        coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
                         for coord, match in maybe_prefix_data.value.items()
                         if isinstance(coord, Coordinate)
                         and isinstance(getattr(match, "value", None), list)
@@ -218,7 +218,7 @@ class MoEBeamSearcher:
         for prefix, found in dht_responses.items():
             if found and isinstance(found.value, dict):
                 successors[prefix] = {
-                    coord: UidEndpoint(match.value[0], PeerInfo.from_tuple(match.value[1]))
+                    coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
                     for coord, match in found.value.items()
                     if isinstance(coord, Coordinate)
                     and 0 <= coord < grid_size
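
For context, the keyword-argument form only reads cleanly because UidEndpoint exposes named fields. A minimal self-contained sketch of the pattern (the NamedTuple and PeerInfoStub below are illustrative stand-ins; the real UidEndpoint and PeerInfo come from hivemind):

from typing import NamedTuple, Tuple

class PeerInfoStub(NamedTuple):    # stand-in for hivemind.PeerInfo
    peer_id: str
    addrs: Tuple[str, ...]

class UidEndpoint(NamedTuple):     # assumed shape: (uid, peer_info)
    uid: str
    peer_info: PeerInfoStub

# DHT matches are assumed to arrive as [uid, peer_info_tuple] lists;
# keyword arguments make the field mapping explicit at the call site
match_value = ["ffn.3.7", ("QmPeerIdStub", ("/ip4/127.0.0.1/tcp/1337",))]
endpoint = UidEndpoint(uid=match_value[0], peer_info=PeerInfoStub(*match_value[1]))
print(endpoint.uid, endpoint.peer_info.peer_id)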

+ 27 - 23
hivemind/moe/client/expert.py

@@ -24,7 +24,7 @@ from hivemind.utils import (
     nested_pack,
 )
 from hivemind.utils.mpfuture import MPFuture
-from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
+from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
@@ -35,6 +35,8 @@ def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandler
 
 @dataclass(frozen=True)
 class RemoteExpertInfo:
+    """A simple data class containing uid of expert and server PeerInfo"""
+
     uid: str
     peer_info: PeerInfo
 
@@ -45,7 +47,9 @@ class RemoteExpert(nn.Module):
     Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
     Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
-    :param uid: unique expert identifier
+
+    :param expert_info: RemoteExpertInfo with uid and server PeerInfo
+    :param p2p: P2P instance connected to the running p2pd
     """
 
     def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
@@ -135,7 +139,7 @@ def batch_create_remote_experts(
 
 
 async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
 
     grad_inputs = await stub.rpc_backward_stream(
         amap_in_executor(
@@ -143,8 +147,8 @@ async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Te
             iter_as_aiter(split),
         ),
     )
-
-    return await gather_from_streaming(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await combine_and_deserialize_from_streaming(tensors_stream, deserialize_torch_tensor)
 
 
 async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
@@ -155,23 +159,19 @@ async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Ten
 
 
 async def expert_backward(
-    uid: str, inputs_and_grads: Sequence[torch.Tensor], compressions: Iterable, stub
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
 ) -> List[torch.Tensor]:
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs_and_grads, compressions)
-    )
-
     size = 0
     for t in inputs_and_grads:
         size += t.element_size() * t.nelement()
-        if size >= DEFAULT_MAX_MSG_SIZE:
+        if size > DEFAULT_MAX_MSG_SIZE:
             return await _backward_stream(uid, serialized_tensors, stub)
     else:
         return await _backward_unary(uid, serialized_tensors, stub)
 
 
 async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
 
     outputs = await stub.rpc_forward_stream(
         amap_in_executor(
@@ -180,7 +180,8 @@ async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Ten
         ),
     )
 
-    return await gather_from_streaming(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
+    return await combine_and_deserialize_from_streaming(tensors_stream, deserialize_torch_tensor)
 
 
 async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
@@ -190,14 +191,13 @@ async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tens
     return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
 
-async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions: Iterable, stub) -> List[torch.Tensor]:
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs, compressions)
-    )
+async def expert_forward(
+    uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
     size = 0
     for t in inputs:
         size += t.element_size() * t.nelement()
-        if size >= DEFAULT_MAX_MSG_SIZE:
+        if size > DEFAULT_MAX_MSG_SIZE:
             return await _forward_stream(uid, serialized_tensors, stub)
     else:
         return await _forward_unary(uid, serialized_tensors, stub)
@@ -220,10 +220,11 @@ class _RemoteModuleCall(torch.autograd.Function):
         inputs = tuple(tensor.cpu().detach() for tensor in inputs)
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
-
-        deserialized_outputs = _RemoteExpertWorker.run_coroutine(
-            expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         )
+        deserialized_outputs = _RemoteExpertWorker.run_coroutine(expert_forward(uid, inputs, serialized_tensors, stub))
 
         return tuple(deserialized_outputs)
 
@@ -233,9 +234,12 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
         deserialized_grad_inputs = _RemoteExpertWorker.run_coroutine(
-            expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
+            expert_backward(ctx.uid, inputs_and_grad_outputs, serialized_tensors, ctx.stub)
         )
 
         return (DUMMY, None, None, None, *deserialized_grad_inputs)
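
The reworked expert_forward/expert_backward take pre-serialized tensors from the caller and pick between the unary and the streaming RPC based on total payload size. A simplified, self-contained sketch of that dispatch (the threshold constant and the transport labels are placeholders, not the hivemind API):

import torch

MAX_UNARY_PAYLOAD = 2 ** 21  # placeholder, stands in for DEFAULT_MAX_MSG_SIZE

def pick_transport(tensors):
    """Return "stream" once the summed tensor payload exceeds the unary limit, else "unary"."""
    size = 0
    for t in tensors:
        size += t.element_size() * t.nelement()
        if size > MAX_UNARY_PAYLOAD:
            return "stream"  # early exit: payload is already too large for one message
    else:
        return "unary"       # for/else: the loop finished without exceeding the limit

print(pick_transport([torch.zeros(10)]))         # unary
print(pick_transport([torch.zeros(1_000_000)]))  # stream

Serialization itself now happens in _RemoteModuleCall, which zips each tensor with its schema entry, so the helpers no longer need a separate compressions argument.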

+ 13 - 13
hivemind/moe/client/moe.py

@@ -9,16 +9,10 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
+from hivemind.compression import serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import (
-    DUMMY,
-    RemoteExpert,
-    _get_expert_stub,
-    expert_backward,
-    expert_forward,
-)
-
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub, expert_backward, expert_forward
 from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
@@ -233,10 +227,13 @@ class _RemoteCallMany(torch.autograd.Function):
         pending_tasks: Dict[Future, Tuple[int, int]] = {}
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
-                compressions = (p.compression for p in nested_flatten(info["forward_schema"]))
                 stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
+                serialized_tensors = (
+                    serialize_torch_tensor(tensor, proto.compression)
+                    for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
+                )
                 new_task = _RemoteExpertWorker.run_coroutine(
-                    expert_forward(expert.uid, flat_inputs_per_sample[i], compressions, stub),
+                    expert_forward(expert.uid, flat_inputs_per_sample[i], serialized_tensors, stub),
                     return_future=True,
                 )
                 pending_tasks[new_task] = (i, j)
@@ -326,9 +323,12 @@ class _RemoteCallMany(torch.autograd.Function):
             expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
             stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
-            compressions = (p.compression for p in backward_schema)
+            serialized_tensors = (
+                serialize_torch_tensor(tensor, proto.compression)
+                for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+            )
             new_task = _RemoteExpertWorker.run_coroutine(
-                expert_backward(expert.uid, inputs_and_grad_outputs, compressions, stub), return_future=True
+                expert_backward(expert.uid, inputs_and_grad_outputs, serialized_tensors, stub), return_future=True
             )
             pending_tasks[new_task] = (i, j)
 
@@ -419,7 +419,7 @@ def _process_dispatched_task(task: Future, detect_anomalies: bool) -> Optional[T
         logger.warning(f"Task {task} failed: {type(task.exception())}")
         return None
 
-    outputs = tuple(task.result())
+    outputs = task.result()
     for tensor in outputs:
         if detect_anomalies and not tensor.isfinite().all():
             logger.error(f"Task {task} failed: output tensor contains nan/inf values")
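
Note that serialized_tensors is a generator expression, and a generator can be consumed only once; this is why it is built inside the per-expert loop rather than hoisted out of it. A small illustration of the single-use behaviour (serialize_stub is a placeholder, not serialize_torch_tensor):

def serialize_stub(tensor, compression):
    return (tensor, compression)  # placeholder for real serialization

inputs, schema = [1, 2], ["NONE", "FLOAT16"]
serialized = (serialize_stub(t, c) for t, c in zip(inputs, schema))
print(list(serialized))  # [(1, 'NONE'), (2, 'FLOAT16')]
print(list(serialized))  # [] -- already exhausted, so each expert call needs a fresh generator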

+ 7 - 5
hivemind/moe/server/connection_handler.py

@@ -1,6 +1,6 @@
 import asyncio
 import multiprocessing as mp
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Union
+from typing import AsyncIterator, Dict, Iterable, List, Tuple, Union
 
 import torch
 
@@ -12,8 +12,8 @@ from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
-from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
+from hivemind.utils.asyncio import amap_in_executor, switch_to_uvloop
+from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 logger = get_logger(__name__)
@@ -43,6 +43,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
                 self._p2p = await self.dht.replicate_p2p()
                 await self.add_p2p_handlers(self._p2p, balanced=True)
 
+                # wait forever
                 await asyncio.Future()
 
             except Exception as e:
@@ -70,11 +71,12 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             if expert_uid is None:
                 expert_uid = req.uid
             elif expert_uid != req.uid:
-                raise ValueError("Expert uids differ in one reques")
+                raise ValueError("Expert uids differ in one request")
 
             return req.tensors
 
-        inputs = await gather_from_streaming(requests, _unpack, deserialize_torch_tensor)
+        tensors_stream = amap_in_executor(_unpack, requests)
+        inputs = await combine_and_deserialize_from_streaming(tensors_stream, deserialize_torch_tensor)
         return expert_uid, inputs
 
     async def _process_inputs(
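
On the server side, the connection handler validates that every message in a request stream refers to the same expert before combining the tensor chunks. A synchronous, self-contained sketch of that check (RequestStub stands in for the streamed runtime_pb2 request messages):

from dataclasses import dataclass
from typing import Iterable, List, Tuple

@dataclass
class RequestStub:           # stand-in for one streamed expert request chunk
    uid: str
    tensors: List[bytes]

def unpack_stream(requests: Iterable[RequestStub]) -> Tuple[str, List[bytes]]:
    """Collect tensor chunks from a request stream, insisting on a single expert uid."""
    expert_uid, chunks = None, []
    for req in requests:
        if expert_uid is None:
            expert_uid = req.uid
        elif expert_uid != req.uid:
            raise ValueError("Expert uids differ in one request")
        chunks.extend(req.tensors)
    return expert_uid, chunks

uid, chunks = unpack_stream([RequestStub("ffn.0", [b"a"]), RequestStub("ffn.0", [b"b"])])
print(uid, chunks)  # ffn.0 [b'a', b'b']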

+ 6 - 7
hivemind/moe/server/dht_handler.py

@@ -19,18 +19,17 @@ from hivemind.utils import MPFuture, get_dht_time
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, peer_id: PeerID, update_period: int = 5, **kwargs):
+    def __init__(self, experts, dht: DHT, update_period: int = 5, **kwargs):
         super().__init__(**kwargs)
-        self.peer_id = peer_id
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), self.peer_id)
+        declare_experts(self.dht, self.experts.keys(), self.dht.peer_id)
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), self.peer_id)
+            declare_experts(self.dht, self.experts.keys(), self.dht.peer_id)
 
 
 def declare_experts(
@@ -97,7 +96,7 @@ async def _get_experts(
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     for i, uid in enumerate(uids):
-        elem = found[uid]
-        if elem is not None and isinstance(elem.value, tuple):
-            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(elem.value))
+        expert_info_for_uid = found[uid]
+        if expert_info_for_uid is not None and isinstance(expert_info_for_uid.value, tuple):
+            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(expert_info_for_uid.value))
     return experts
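
With the explicit peer_id argument gone, DHTHandlerThread reads the peer id straight from the DHT it was given and periodically re-announces its experts. A self-contained sketch of that re-announcement loop (declare_fn stands in for declare_experts(self.dht, keys, self.dht.peer_id)):

import threading

class PeriodicDeclarer(threading.Thread):
    """Simplified DHTHandlerThread: announce once, then re-announce every update_period seconds."""

    def __init__(self, declare_fn, keys, update_period: float = 5.0):
        super().__init__(daemon=True)
        self.declare_fn, self.keys, self.update_period = declare_fn, keys, update_period
        self.stop = threading.Event()

    def run(self) -> None:
        self.declare_fn(self.keys)                     # initial announcement
        while not self.stop.wait(self.update_period):  # wakes up early if stop is set
            self.declare_fn(self.keys)

announcer = PeriodicDeclarer(lambda keys: print("declared", keys), ["expert.0"], update_period=0.2)
announcer.start()
announcer.stop.set()
announcer.join()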

+ 2 - 3
hivemind/moe/server/server.py

@@ -34,7 +34,7 @@ logger = get_logger(__name__)
 
 class Server(threading.Thread):
     """
-    Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
+    Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts.
     After creation, a server should be started: see Server.run or Server.run_in_background.
 
     A working server does two things:
@@ -75,7 +75,6 @@ class Server(threading.Thread):
             self.dht_handler_thread = DHTHandlerThread(
                 experts=self.experts,
                 dht=self.dht,
-                peer_id=self.dht.peer_id,
                 update_period=self.update_period,
                 daemon=True,
             )
@@ -301,7 +300,7 @@ class Server(threading.Thread):
 
 @contextmanager
 def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
-    """A context manager that creates server in a background thread, awaits .ready on entry and shuts down on exit"""
+    """A context manager that creates server in a background , awaits .ready on entry and shuts down on exit"""
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     try:
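
For reference, background_server is used as a context manager; a hypothetical usage sketch (the keyword arguments below are illustrative only and are simply forwarded to the server runner, they are not taken from this diff):

from hivemind.moe.server.server import background_server

# hypothetical arguments, passed through *args/**kwargs to the background server process
with background_server(expert_uids=["expert.0"], device="cpu") as server_peer_info:
    # server_peer_info describes the server launched in the background process;
    # clients can connect to the hosted experts while this block is active
    print(server_peer_info)
# on exit, the server is shut down, waiting up to shutdown_timeout seconds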

+ 5 - 1
hivemind/utils/__init__.py

@@ -6,6 +6,10 @@ from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils.streaming import *
+from hivemind.utils.streaming import (
+    combine_and_deserialize_from_streaming,
+    combine_from_streaming,
+    split_for_streaming,
+)
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 0 - 20
hivemind/utils/networking.py

@@ -10,26 +10,6 @@ Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://netwo
 LOCALHOST = "127.0.0.1"
 
 
-def get_port(endpoint: Endpoint) -> Optional[Port]:
-    """get port or None if port is undefined"""
-    # TODO: find a standard way to get port, make sure it works in malformed ports
-    try:
-        return int(endpoint[endpoint.rindex(":") + 1 :], base=10)
-    except ValueError:  # :* or not specified
-        return None
-
-
-def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
-    assert endpoint.endswith(":*") or get_port(endpoint) is not None, endpoint
-    return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
-
-
-def strip_port(endpoint: Endpoint) -> Hostname:
-    """Removes port from the end of endpoint. If port is not specified, does nothing"""
-    maybe_port = endpoint[endpoint.rindex(":") + 1 :]
-    return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
-
-
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """
     Finds a tcp port that can be occupied with a socket with *params and use *opt options.

+ 16 - 18
hivemind/utils/streaming.py

@@ -1,5 +1,5 @@
 """
-Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
+Utilities for streaming tensors
 """
 
 from __future__ import annotations
@@ -21,7 +21,7 @@ def split_for_streaming(
     serialized_tensor: runtime_pb2.Tensor,
     chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
 ) -> Iterator[runtime_pb2.Tensor]:
-    """Split serialized_tensor into multiple chunks for gRPC streaming"""
+    """Split serialized_tensor into multiple chunks for streaming"""
     buffer = memoryview(serialized_tensor.buffer)
     num_chunks = len(range(0, len(buffer), chunk_size_bytes))
     yield runtime_pb2.Tensor(
@@ -52,25 +52,23 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
 StreamMessage = TypeVar("StreamMessage")
 
 
-async def gather_from_streaming(
-    stream: AsyncIterator[StreamMessage],
-    key: Callable[[StreamMessage], Iterable[runtime_pb2.Tensor]],
+async def combine_and_deserialize_from_streaming(
+    stream: AsyncIterator[Iterable[runtime_pb2.Tensor]],
     deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
 ) -> List[torch.Tensor]:
-    """Async wrapper of combine_from_streaming allowing to work with arbitrary messages gathered from AsyncIterator"""
+    """Async wrapper of combine_from_streaming allowing to combine tensors from async stream of parts and deserialize"""
 
     tensors = []
-    parts = []
-
-    async for msg in stream:
-        parts_stream = key(msg)
-        for part in parts_stream:
-            if part.dtype and parts:
-                tensors.append(deserializer(combine_from_streaming(parts)))
-                parts = []
-
-            parts.append(part)
-    if parts:
-        tensors.append(deserializer(combine_from_streaming(parts)))
+    tensor_parts = []
+
+    async for parts in stream:
+        for part in parts:
+            if part.dtype and tensor_parts:
+                tensors.append(deserializer(combine_from_streaming(tensor_parts)))
+                tensor_parts = []
+
+            tensor_parts.append(part)
+    if tensor_parts:
+        tensors.append(deserializer(combine_from_streaming(tensor_parts)))
 
     return tensors
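
Taken together, split_for_streaming and the renamed combine_and_deserialize_from_streaming form a round trip: the sender splits each serialized tensor into chunks, and the receiver detects tensor boundaries via the dtype field that only the first chunk of each tensor carries. A sketch of that round trip, assuming serialize_torch_tensor/deserialize_torch_tensor are importable from hivemind.compression (as elsewhere in this change set):

import asyncio

import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming

async def roundtrip():
    tensors = [torch.randn(4, 4), torch.randn(8)]
    # serialize each tensor and split it into small chunks, as a streaming sender would
    chunks = [
        part
        for tensor in tensors
        for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes=16)
    ]

    async def stream():
        # the combiner expects an async stream of *iterables* of tensor parts,
        # mirroring RPC messages that each carry a `tensors` field
        for chunk in chunks:
            yield [chunk]

    restored = await combine_and_deserialize_from_streaming(stream(), deserialize_torch_tensor)
    assert all(torch.allclose(a, b) for a, b in zip(tensors, restored))

asyncio.run(roundtrip())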