
review issues fix

Pavel Samygin 3 years ago
commit 157c14422a

+ 2 - 2
hivemind/moe/client/beam_search.py

@@ -151,7 +151,7 @@ class MoEBeamSearcher:
                 maybe_prefix_data = await pending_task
                 if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                     successors = {
-                        coord: UidEndpoint(match.value[0], PeerInfo.from_tuple(match.value[1]))
+                        coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
                         for coord, match in maybe_prefix_data.value.items()
                         if isinstance(coord, Coordinate)
                         and isinstance(getattr(match, "value", None), list)
@@ -218,7 +218,7 @@ class MoEBeamSearcher:
         for prefix, found in dht_responses.items():
             if found and isinstance(found.value, dict):
                 successors[prefix] = {
-                    coord: UidEndpoint(match.value[0], PeerInfo.from_tuple(match.value[1]))
+                    coord: UidEndpoint(uid=match.value[0], peer_info=PeerInfo.from_tuple(match.value[1]))
                     for coord, match in found.value.items()
                     if isinstance(coord, Coordinate)
                     and 0 <= coord < grid_size

+ 27 - 23
hivemind/moe/client/expert.py

@@ -24,7 +24,7 @@ from hivemind.utils import (
     nested_pack,
 )
 from hivemind.utils.mpfuture import MPFuture
-from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
+from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
@@ -35,6 +35,8 @@ def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandler
 
 @dataclass(frozen=True)
 class RemoteExpertInfo:
+    """A simple data class containing uid of expert and server PeerInfo"""
+
     uid: str
     peer_info: PeerInfo
 
@@ -45,7 +47,9 @@ class RemoteExpert(nn.Module):
     Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
     Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
-    :param uid: unique expert identifier
+
+    :param expert_info: RemoteExpertInfo with uid and server PeerInfo
+    :param p2p: P2P instance connected to the running p2pd
     """
 
     def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
@@ -135,7 +139,7 @@ def batch_create_remote_experts(
 
 
 async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
 
     grad_inputs = await stub.rpc_backward_stream(
         amap_in_executor(
@@ -143,8 +147,8 @@ async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Te
             iter_as_aiter(split),
         ),
     )
-
-    return await gather_from_streaming(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await combine_and_deserialize_from_streaming(tensors_stream, deserialize_torch_tensor)
 
 
 async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
@@ -155,23 +159,19 @@ async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Ten
 
 
 async def expert_backward(
-    uid: str, inputs_and_grads: Sequence[torch.Tensor], compressions: Iterable, stub
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
 ) -> List[torch.Tensor]:
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs_and_grads, compressions)
-    )
-
     size = 0
     for t in inputs_and_grads:
         size += t.element_size() * t.nelement()
-        if size >= DEFAULT_MAX_MSG_SIZE:
+        if size > DEFAULT_MAX_MSG_SIZE:
             return await _backward_stream(uid, serialized_tensors, stub)
     else:
         return await _backward_unary(uid, serialized_tensors, stub)
 
 
 async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
 
     outputs = await stub.rpc_forward_stream(
         amap_in_executor(
@@ -180,7 +180,8 @@ async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Ten
         ),
     )
 
-    return await gather_from_streaming(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
+    return await combine_and_deserialize_from_streaming(tensors_stream, deserialize_torch_tensor)
 
 
 async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
@@ -190,14 +191,13 @@ async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tens
     return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
 
-async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions: Iterable, stub) -> List[torch.Tensor]:
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs, compressions)
-    )
+async def expert_forward(
+    uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
     size = 0
     for t in inputs:
         size += t.element_size() * t.nelement()
-        if size >= DEFAULT_MAX_MSG_SIZE:
+        if size > DEFAULT_MAX_MSG_SIZE:
             return await _forward_stream(uid, serialized_tensors, stub)
     else:
         return await _forward_unary(uid, serialized_tensors, stub)
@@ -220,10 +220,11 @@ class _RemoteModuleCall(torch.autograd.Function):
         inputs = tuple(tensor.cpu().detach() for tensor in inputs)
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
-
-        deserialized_outputs = _RemoteExpertWorker.run_coroutine(
-            expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         )
+        deserialized_outputs = _RemoteExpertWorker.run_coroutine(expert_forward(uid, inputs, serialized_tensors, stub))
 
         return tuple(deserialized_outputs)
 
@@ -233,9 +234,12 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
         deserialized_grad_inputs = _RemoteExpertWorker.run_coroutine(
-            expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
+            expert_backward(ctx.uid, inputs_and_grad_outputs, serialized_tensors, ctx.stub)
         )
 
         return (DUMMY, None, None, None, *deserialized_grad_inputs)
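
The expert_forward and expert_backward changes above keep a single dispatch rule: tensors are serialized by the caller, and the call switches to the streaming RPC only once the combined payload no longer fits into one message. Below is a minimal standalone sketch of that size check, not part of this commit; the DEFAULT_MAX_MSG_SIZE value here is an assumed placeholder (the real constant lives in hivemind.p2p.p2p_daemon).

    import torch

    DEFAULT_MAX_MSG_SIZE = 4 * 1024 * 1024  # assumed placeholder, for illustration only

    def should_stream(tensors) -> bool:
        """Mirror of the loop above: stream once the raw payload exceeds one message."""
        size = 0
        for t in tensors:
            size += t.element_size() * t.nelement()
            if size > DEFAULT_MAX_MSG_SIZE:  # strict '>' after this commit (was '>=')
                return True
        return False

    print(should_stream([torch.zeros(1024)]))              # False -> unary RPC
    print(should_stream([torch.zeros(8 * 1024 * 1024)]))   # True  -> streaming RPC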

+ 13 - 13
hivemind/moe/client/moe.py

@@ -9,16 +9,10 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
+from hivemind.compression import serialize_torch_tensor
 from hivemind.dht import DHT
 from hivemind.moe.client.beam_search import MoEBeamSearcher
-from hivemind.moe.client.expert import (
-    DUMMY,
-    RemoteExpert,
-    _get_expert_stub,
-    expert_backward,
-    expert_forward,
-)
-
+from hivemind.moe.client.expert import DUMMY, RemoteExpert, _get_expert_stub, expert_backward, expert_forward
 from hivemind.moe.client.remote_expert_worker import _RemoteExpertWorker
 from hivemind.moe.server.expert_uid import UID_DELIMITER
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError
@@ -233,10 +227,13 @@ class _RemoteCallMany(torch.autograd.Function):
         pending_tasks: Dict[Future, Tuple[int, int]] = {}
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
-                compressions = (p.compression for p in nested_flatten(info["forward_schema"]))
                 stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
+                serialized_tensors = (
+                    serialize_torch_tensor(tensor, proto.compression)
+                    for tensor, proto in zip(flat_inputs_per_sample[i], nested_flatten(info["forward_schema"]))
+                )
                 new_task = _RemoteExpertWorker.run_coroutine(
-                    expert_forward(expert.uid, flat_inputs_per_sample[i], compressions, stub),
+                    expert_forward(expert.uid, flat_inputs_per_sample[i], serialized_tensors, stub),
                     return_future=True,
                 )
                 pending_tasks[new_task] = (i, j)
@@ -326,9 +323,12 @@ class _RemoteCallMany(torch.autograd.Function):
             expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
             stub = _get_expert_stub(expert.p2p, expert.server_peer_info)
             inputs_and_grad_outputs = tuple(nested_flatten((inputs_ij, grad_outputs_ij)))
-            compressions = (p.compression for p in backward_schema)
+            serialized_tensors = (
+                serialize_torch_tensor(tensor, proto.compression)
+                for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+            )
             new_task = _RemoteExpertWorker.run_coroutine(
-                expert_backward(expert.uid, inputs_and_grad_outputs, compressions, stub), return_future=True
+                expert_backward(expert.uid, inputs_and_grad_outputs, serialized_tensors, stub), return_future=True
             )
             pending_tasks[new_task] = (i, j)
 
@@ -419,7 +419,7 @@ def _process_dispatched_task(task: Future, detect_anomalies: bool) -> Optional[T
         logger.warning(f"Task {task} failed: {type(task.exception())}")
         return None
 
-    outputs = tuple(task.result())
+    outputs = task.result()
     for tensor in outputs:
         if detect_anomalies and not tensor.isfinite().all():
             logger.error(f"Task {task} failed: output tensor contains nan/inf values")
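
In _RemoteCallMany above, serialization now happens inside the per-expert loop because a generator expression is exhausted after a single pass: one shared generator of serialized tensors could not be reused across experts. A toy illustration of that single-use behaviour (plain Python, not library code):

    serialized = (x * x for x in range(3))  # stand-in for a generator of serialized tensors
    print(list(serialized))  # [0, 1, 4] -- the first consumer gets everything
    print(list(serialized))  # []        -- a second consumer gets nothing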

+ 7 - 5
hivemind/moe/server/connection_handler.py

@@ -1,6 +1,6 @@
 import asyncio
 import multiprocessing as mp
-from typing import AsyncIterator, Dict, Iterable, List, Optional, Tuple, Union
+from typing import AsyncIterator, Dict, Iterable, List, Tuple, Union
 
 import torch
 
@@ -12,8 +12,8 @@ from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
-from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
+from hivemind.utils.asyncio import amap_in_executor, switch_to_uvloop
+from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 logger = get_logger(__name__)
@@ -43,6 +43,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
                 self._p2p = await self.dht.replicate_p2p()
                 await self.add_p2p_handlers(self._p2p, balanced=True)
 
+                # wait forever
                 await asyncio.Future()
 
             except Exception as e:
@@ -70,11 +71,12 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             if expert_uid is None:
                 expert_uid = req.uid
             elif expert_uid != req.uid:
-                raise ValueError("Expert uids differ in one reques")
+                raise ValueError("Expert uids differ in one request")
 
             return req.tensors
 
-        inputs = await gather_from_streaming(requests, _unpack, deserialize_torch_tensor)
+        tensors_stream = amap_in_executor(_unpack, requests)
+        inputs = await combine_and_deserialize_from_streaming(tensors_stream, deserialize_torch_tensor)
         return expert_uid, inputs
 
     async def _process_inputs(
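
The new unpacking path maps _unpack over the incoming request stream with amap_in_executor and then merges the resulting tensor parts. A hedged sketch of the same pattern with a made-up request type (FakeRequest is a stand-in for the real runtime_pb2 message; assumes hivemind is installed):

    import asyncio
    from dataclasses import dataclass

    from hivemind.utils import as_aiter
    from hivemind.utils.asyncio import amap_in_executor

    @dataclass
    class FakeRequest:  # hypothetical stand-in for the protobuf request message
        uid: str
        tensors: list

    async def main():
        requests = as_aiter(FakeRequest("expert.0", ["part1", "part2"]), FakeRequest("expert.0", ["part3"]))
        tensors_stream = amap_in_executor(lambda req: req.tensors, requests)  # same shape as _unpack above
        parts = [part async for chunk in tensors_stream for part in chunk]
        print(parts)  # ['part1', 'part2', 'part3']

    asyncio.run(main())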

+ 6 - 7
hivemind/moe/server/dht_handler.py

@@ -19,18 +19,17 @@ from hivemind.utils import MPFuture, get_dht_time
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHT, peer_id: PeerID, update_period: int = 5, **kwargs):
+    def __init__(self, experts, dht: DHT, update_period: int = 5, **kwargs):
         super().__init__(**kwargs)
-        self.peer_id = peer_id
         self.experts = experts
         self.dht = dht
         self.update_period = update_period
         self.stop = threading.Event()
 
     def run(self) -> None:
-        declare_experts(self.dht, self.experts.keys(), self.peer_id)
+        declare_experts(self.dht, self.experts.keys(), self.dht.peer_id)
         while not self.stop.wait(self.update_period):
-            declare_experts(self.dht, self.experts.keys(), self.peer_id)
+            declare_experts(self.dht, self.experts.keys(), self.dht.peer_id)
 
 
 def declare_experts(
@@ -97,7 +96,7 @@ async def _get_experts(
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     for i, uid in enumerate(uids):
-        elem = found[uid]
-        if elem is not None and isinstance(elem.value, tuple):
-            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(elem.value))
+        expert_info_for_uid = found[uid]
+        if expert_info_for_uid is not None and isinstance(expert_info_for_uid.value, tuple):
+            experts[i] = RemoteExpertInfo(uid, PeerInfo.from_tuple(expert_info_for_uid.value))
     return experts

+ 2 - 3
hivemind/moe/server/server.py

@@ -34,7 +34,7 @@ logger = get_logger(__name__)
 
 class Server(threading.Thread):
     """
-    Server allows you to host "experts" - pytorch sub-networks used by Decentralized Mixture of Experts.
+    Server allows you to host "experts" - pytorch subnetworks used by Decentralized Mixture of Experts.
     After creation, a server should be started: see Server.run or Server.run_in_background.
 
     A working server does two things:
@@ -75,7 +75,6 @@ class Server(threading.Thread):
             self.dht_handler_thread = DHTHandlerThread(
                 experts=self.experts,
                 dht=self.dht,
-                peer_id=self.dht.peer_id,
                 update_period=self.update_period,
                 daemon=True,
             )
@@ -301,7 +300,7 @@ class Server(threading.Thread):
 
 @contextmanager
 def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
-    """A context manager that creates server in a background thread, awaits .ready on entry and shuts down on exit"""
+    """A context manager that creates server in a background process, awaits .ready on entry and shuts down on exit"""
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     try:

+ 5 - 1
hivemind/utils/__init__.py

@@ -6,6 +6,10 @@ from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils.streaming import *
+from hivemind.utils.streaming import (
+    combine_and_deserialize_from_streaming,
+    combine_from_streaming,
+    split_for_streaming,
+)
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 0 - 20
hivemind/utils/networking.py

@@ -10,26 +10,6 @@ Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://netwo
 LOCALHOST = "127.0.0.1"
 
 
-def get_port(endpoint: Endpoint) -> Optional[Port]:
-    """get port or None if port is undefined"""
-    # TODO: find a standard way to get port, make sure it works in malformed ports
-    try:
-        return int(endpoint[endpoint.rindex(":") + 1 :], base=10)
-    except ValueError:  # :* or not specified
-        return None
-
-
-def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
-    assert endpoint.endswith(":*") or get_port(endpoint) is not None, endpoint
-    return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
-
-
-def strip_port(endpoint: Endpoint) -> Hostname:
-    """Removes port from the end of endpoint. If port is not specified, does nothing"""
-    maybe_port = endpoint[endpoint.rindex(":") + 1 :]
-    return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
-
-
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """
     Finds a tcp port that can be occupied with a socket with *params and use *opt options.

+ 16 - 18
hivemind/utils/streaming.py

@@ -1,5 +1,5 @@
 """
-Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
+Utilities for streaming tensors
 """
 
 from __future__ import annotations
@@ -21,7 +21,7 @@ def split_for_streaming(
     serialized_tensor: runtime_pb2.Tensor,
     chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
 ) -> Iterator[runtime_pb2.Tensor]:
-    """Split serialized_tensor into multiple chunks for gRPC streaming"""
+    """Split serialized_tensor into multiple chunks for streaming"""
     buffer = memoryview(serialized_tensor.buffer)
     num_chunks = len(range(0, len(buffer), chunk_size_bytes))
     yield runtime_pb2.Tensor(
@@ -52,25 +52,23 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
 StreamMessage = TypeVar("StreamMessage")
 
 
-async def gather_from_streaming(
-    stream: AsyncIterator[StreamMessage],
-    key: Callable[[StreamMessage], Iterable[runtime_pb2.Tensor]],
+async def combine_and_deserialize_from_streaming(
+    stream: AsyncIterator[Iterable[runtime_pb2.Tensor]],
     deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
 ) -> List[torch.Tensor]:
-    """Async wrapper of combine_from_streaming allowing to work with arbitrary messages gathered from AsyncIterator"""
+    """Async wrapper of combine_from_streaming allowing to combine tensors from async stream of parts and deserialize"""
 
     tensors = []
-    parts = []
-
-    async for msg in stream:
-        parts_stream = key(msg)
-        for part in parts_stream:
-            if part.dtype and parts:
-                tensors.append(deserializer(combine_from_streaming(parts)))
-                parts = []
-
-            parts.append(part)
-    if parts:
-        tensors.append(deserializer(combine_from_streaming(parts)))
+    tensor_parts = []
+
+    async for parts in stream:
+        for part in parts:
+            if part.dtype and tensor_parts:
+                tensors.append(deserializer(combine_from_streaming(tensor_parts)))
+                tensor_parts = []
+
+            tensor_parts.append(part)
+    if tensor_parts:
+        tensors.append(deserializer(combine_from_streaming(tensor_parts)))
 
     return tensors
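
For context, a hedged usage sketch of the synchronous core that combine_and_deserialize_from_streaming wraps: split one serialized tensor into chunks, then reassemble and deserialize it. Only the first chunk carries dtype metadata, which is what the "if part.dtype" check above relies on to detect tensor boundaries (assumes hivemind and torch are installed; not part of this commit).

    import torch

    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
    from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

    tensor = torch.randn(1000)
    serialized = serialize_torch_tensor(tensor)  # CompressionType.NONE by default

    chunks = list(split_for_streaming(serialized, chunk_size_bytes=1024))
    assert chunks[0].dtype and all(not c.dtype for c in chunks[1:])  # metadata only on the first chunk

    restored = deserialize_torch_tensor(combine_from_streaming(chunks))
    assert torch.allclose(tensor, restored)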