Browse source code

restored empty lines, deleted unused grpc utils

Pavel Samygin 3 years ago
parent
commit
d56742b333

+ 1 - 1
hivemind/averaging/averager.py

@@ -37,8 +37,8 @@ from hivemind.utils.asyncio import (
     enter_asynchronously,
     switch_to_uvloop,
 )
-from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
 
 # flavour types

+ 1 - 1
hivemind/moe/client/beam_search.py

@@ -385,7 +385,7 @@ class MoEBeamSearcher:
             ),
             return_future,
         )
-        print(result)
+
         if return_future:
             return RemoteExpertWorker.spawn_experts_bulk_future(result, self.dht)
         return RemoteExpertWorker.spawn_experts_bulk(result, self.dht)

+ 1 - 1
hivemind/moe/client/expert.py

@@ -26,8 +26,8 @@ from hivemind.utils import (
     nested_pack,
     switch_to_uvloop,
 )
-from hivemind.utils.grpc import gather_from_streaming, split_for_streaming
 from hivemind.utils.mpfuture import MPFuture
+from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 

+ 1 - 1
hivemind/moe/client/moe.py

@@ -101,7 +101,7 @@ class RemoteMixtureOfExperts(nn.Module):
         chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
             [scores.detach().cpu().numpy() for scores in grid_scores], self.k_best
         )
-        print(chosen_experts)
+
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))

+ 4 - 0
hivemind/moe/client/switch_moe.py

@@ -80,6 +80,7 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
 
         # Compute scores, find most appropriate experts with beam search
         grid_scores = self.proj(input_for_gating).split_with_sizes(self.beam_search.grid_size, dim=-1)
+
         grid_dropout_masks = (
             (
                 torch.rand(size=(dim_size,), dtype=input_for_gating.dtype, device=input_for_gating.device)
@@ -95,10 +96,12 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
             )
             for grid_score, dropout_mask in zip(grid_scores, grid_dropout_masks)
         ]
+
         grid_softmax = [torch.softmax(grid_score, dim=-1) for grid_score in grid_scores_dropout]
         chosen_experts: List[List[RemoteExpert]] = self.beam_search.batch_find_best_experts(
             [scores.detach().cpu() for scores in grid_scores_dropout], self.k_best
         )
+
         if self._expert_info is None:
             try:
                 self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
@@ -109,6 +112,7 @@ class RemoteSwitchMixtureOfExperts(RemoteMixtureOfExperts):
                 )
             except P2PDaemonError as e:
                 logger.warning(f"Failed to get RemoteSwitchMixtureOfExperts.output_shape: {e}")
+
         expert_mask, *expert_outputs = _RemoteCallMany.apply(
             DUMMY,
             chosen_experts,

+ 1 - 1
hivemind/moe/server/connection_handler.py

@@ -13,7 +13,7 @@ from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
-from hivemind.utils.grpc import gather_from_streaming, split_for_streaming
+from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
 logger = get_logger(__name__)

+ 1 - 1
hivemind/moe/server/server.py

@@ -303,7 +303,7 @@ class Server(threading.Thread):
 
 @contextmanager
 def background_server(*args, shutdown_timeout=5, **kwargs) -> PeerInfo:
-    """A context manager that creates server in a background thread, awaits .ready on entry and shutdowns on exit"""
+    """A context manager that creates server in a background thread, awaits .ready on entry and shuts down on exit"""
     pipe, runners_pipe = mp.Pipe(duplex=True)
     runner = mp.Process(target=_server_runner, args=(runners_pipe, *args), kwargs=kwargs)
     try:
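
The docstring above describes a context manager around the MoE server; a minimal usage sketch follows. The import path and the keyword arguments (num_experts, expert_cls, hidden_dim) are assumptions about Server.create's interface at this revision, not something this diff confirms.

from hivemind.moe.server.server import background_server  # import path assumed

# hypothetical arguments: background_server forwards them to Server.create via _server_runner
with background_server(num_experts=2, expert_cls="ffn", hidden_dim=64, device="cpu") as peer_info:
    # .ready has already been awaited here; peer_info describes how to reach the server
    print(peer_info.peer_id, peer_info.addrs)
# on exiting the block the server is shut down (within shutdown_timeout seconds)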

+ 1 - 1
hivemind/utils/__init__.py

@@ -1,5 +1,4 @@
 from hivemind.utils.asyncio import *
-from hivemind.utils.grpc import *
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
@@ -7,5 +6,6 @@ from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
+from hivemind.utils.streaming import *
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 0 - 250
hivemind/utils/grpc.py

@@ -1,250 +0,0 @@
-"""
-Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
-"""
-
-from __future__ import annotations
-
-import os
-import threading
-from typing import (
-    Any,
-    AsyncIterator,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    NamedTuple,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
-
-import grpc
-import torch
-
-from hivemind.proto import runtime_pb2
-from hivemind.utils.logging import get_logger
-from hivemind.utils.networking import Endpoint
-from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration, get_dht_time
-
-logger = get_logger(__name__)
-
-Stub = TypeVar("Stub")
-
-GRPC_KEEPALIVE_OPTIONS = (
-    ("grpc.keepalive_time_ms", 60 * 1000),
-    ("grpc.keepalive_timeout_ms", 60 * 1000),
-    ("grpc.keepalive_permit_without_calls", True),
-    ("grpc.http2.max_pings_without_data", 0),
-    ("grpc.http2.min_time_between_pings_ms", 30 * 1000),
-    ("grpc.http2.min_ping_interval_without_data_ms", 10 * 1000),
-)
-
-
-class ChannelInfo(NamedTuple):
-    target: Endpoint
-    aio: bool
-    options: Tuple[Tuple[str, str], ...]
-    credentials: Optional[grpc.ChannelCredentials]
-    compression: Optional[grpc.Compression]
-
-
-class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.Channel], Dict]]):
-    """
-    A process-wide cache of gRPC channels, supports both normal and aio channels, secure/insecure channels, etc
-    Based on grpcio internal channel cache by Richard Belleville and Lidi Zheng (thanks!)
-    Unlike TimedStorage, ChannelCache actively evicts stale channels even if the cache is not accessed
-    Unlike grpc._simple_stubs.ChannelCache, this implementation supports aio and does not forcibly close active channels
-    """
-
-    MAXIMUM_CHANNELS = int(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_MAXIMUM", 4096))
-    EVICTION_PERIOD_SECONDS = float(os.environ.get("GRPC_PYTHON_MANAGED_CHANNEL_EVICTION_SECONDS", 10 * 60))
-    logger.debug(f"Eviction period = {EVICTION_PERIOD_SECONDS}s, max channels = {MAXIMUM_CHANNELS}")
-
-    _singleton: Optional[ChannelCache] = None
-    _singleton_pid: int = os.getpid()
-    _lock: threading.RLock = threading.RLock()
-    _update_eviction_evt: threading.Event = threading.Event()
-
-    def __init__(self, _created_as_singleton=False):
-        assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
-        super().__init__(maxsize=self.MAXIMUM_CHANNELS)
-        self._is_active = True
-        self._nearest_expiration_time = float("inf")
-        self._eviction_thread = threading.Thread(target=self._evict_stale_channels_in_background, daemon=True)
-        self._eviction_thread.start()
-
-    @classmethod
-    def get_singleton(cls):
-        """Get or create the channel cache for the current process"""
-        with cls._lock:
-            if cls._singleton is None or cls._singleton_pid != os.getpid():
-                if cls._singleton is not None:
-                    cls._singleton._stop_background_thread()
-                cls._singleton, cls._singleton_pid = cls(_created_as_singleton=True), os.getpid()
-            return cls._singleton
-
-    @classmethod
-    def get_stub(
-        cls,
-        target: Endpoint,
-        stub_type: Type[Stub],
-        *,
-        aio: bool,
-        options: Tuple[Tuple[str, Any]] = (),
-        channel_credentials: Optional[grpc.ChannelCredentials] = None,
-        compression: Optional[grpc.Compression] = None,
-    ) -> Stub:
-        """
-        Create a grpc channel with given options or reuse pre-existing one
-
-        :param target: the recipient's address and port
-        :param stub_type: a gRPC stub (client) to be instantiated
-        :param aio: if True, returns grpc.Channel, otherwise returns grpc.aio.Channel
-        :param options: see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
-        :param channel_credentials: if specified, create a secure channel usin these credentials (default = insecure)
-        :param compression: see https://github.com/grpc/grpc/tree/master/examples/python/compression
-        """
-        cache = cls.get_singleton()
-        with cls._lock:
-            key = ChannelInfo(target, aio, tuple(options), channel_credentials, compression)
-            entry: ValueWithExpiration = super(cls, cache).get(key)
-
-            if entry is not None:
-                channel, stubs = entry.value
-            else:
-                channel = cls._create_channel(*key)
-                stubs = {}
-
-            channel._channel.check_connectivity_state(True)
-
-            if stub_type not in stubs:
-                stubs[stub_type] = stub_type(channel)
-
-            # either cache channel or update expiration of an existing channel
-            expiration_time = get_dht_time() + cls.EVICTION_PERIOD_SECONDS
-            super(cls, cache).store(key, (channel, stubs), expiration_time)
-
-            if expiration_time < cache._nearest_expiration_time:
-                cache._nearest_expiration_time = expiration_time
-                cls._update_eviction_evt.set()
-
-            return stubs[stub_type]
-
-    @classmethod
-    def _create_channel(
-        cls,
-        target: Endpoint,
-        aio: bool,
-        extra_options: Tuple[Tuple[str, Any], ...],
-        channel_credentials: Optional[grpc.ChannelCredentials],
-        compression: Optional[grpc.Compression],
-    ) -> Union[grpc.Channel, grpc.aio.Channel]:
-        namespace = grpc.aio if aio else grpc
-
-        options = extra_options + GRPC_KEEPALIVE_OPTIONS
-
-        if channel_credentials is None:
-            logger.debug(
-                f"Creating insecure {namespace} channel with options '{options}' " f"and compression '{compression}'"
-            )
-            return namespace.insecure_channel(target, options=options, compression=compression)
-        else:
-            logger.debug(
-                f"Creating secure {namespace} channel with credentials '{channel_credentials}', "
-                f"options '{options}' and compression '{compression}'"
-            )
-            return namespace.secure_channel(
-                target, credentials=channel_credentials, options=options, compression=compression
-            )
-
-    def _evict_stale_channels_in_background(self):
-        while self._is_active:
-            now = get_dht_time()
-            time_to_wait = max(0.0, self._nearest_expiration_time - now)
-            interrupted_early = self._update_eviction_evt.wait(time_to_wait if time_to_wait != float("inf") else None)
-            if interrupted_early:
-                self._update_eviction_evt.clear()
-                continue
-
-            with self._lock:
-                self._remove_outdated()
-                _, entry = super().top()
-                self._nearest_expiration_time = entry.expiration_time if entry is not None else float("inf")
-
-    def _stop_background_thread(self):
-        with self._lock:
-            self._is_active = False
-            self._update_eviction_evt.set()
-
-    def store(self, *args, **kwargs) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-    def get(self, *args, **kwargs) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-    def top(self) -> ValueError:
-        raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
-
-
-STREAMING_CHUNK_SIZE_BYTES = 2**16
-
-
-def split_for_streaming(
-    serialized_tensor: runtime_pb2.Tensor,
-    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
-) -> Iterator[runtime_pb2.Tensor]:
-    """Split serialized_tensor into multiple chunks for gRPC streaming"""
-    buffer = memoryview(serialized_tensor.buffer)
-    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
-    yield runtime_pb2.Tensor(
-        compression=serialized_tensor.compression,
-        buffer=buffer[:chunk_size_bytes].tobytes(),
-        chunks=num_chunks,
-        size=serialized_tensor.size,
-        dtype=serialized_tensor.dtype,
-        requires_grad=serialized_tensor.requires_grad,
-    )
-    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
-        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
-
-
-def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
-    """Restore a result of split_into_chunks into a single serialized tensor"""
-    stream = iter(stream)
-    first_chunk = next(stream)
-    serialized_tensor = runtime_pb2.Tensor()
-    serialized_tensor.CopyFrom(first_chunk)
-    buffer_chunks = [first_chunk.buffer]
-    for tensor_part in stream:
-        buffer_chunks.append(tensor_part.buffer)
-    serialized_tensor.buffer = b"".join(buffer_chunks)
-    return serialized_tensor
-
-
-StreamMessage = TypeVar("StreamMessage")
-
-
-async def gather_from_streaming(
-    stream: AsyncIterator[StreamMessage],
-    key: Callable[[StreamMessage], Iterable[runtime_pb2.Tensor]],
-    deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
-) -> List[torch.Tensor]:
-    tensors = []
-    parts = []
-
-    async for msg in stream:
-        parts_stream = key(msg)
-        for part in parts_stream:
-            if part.dtype and parts:
-                tensors.append(deserializer(combine_from_streaming(parts)))
-                parts = []
-
-            parts.append(part)
-    if parts:
-        tensors.append(deserializer(combine_from_streaming(parts)))
-
-    return tensors

+ 4 - 0
hivemind/utils/networking.py

@@ -33,6 +33,7 @@ def strip_port(endpoint: Endpoint) -> Hostname:
 def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """
     Finds a tcp port that can be occupied with a socket with *params and use *opt options.
+
     :note: Using this function is discouraged since it often leads to a race condition
            with the "Address is already in use" error if the code is run in parallel.
     """
@@ -51,13 +52,16 @@ def choose_ip_address(
     """
     Currently, some components of hivemind are not converted to work over libp2p and use classical networking.
     To allow other peers reach a server when needed, these components announce a machine's IP address.
+
     This function automatically selects the best IP address to announce among publicly visible multiaddrs
     of this machine identified by libp2p (typically, using the ``P2P.get_visible_maddrs()`` method),
     so a user does not need to define this address manually (unless the user wants to).
+
     The best IP address is chosen using the following logic:
       - Prefer IP addresses from global address blocks
         (in terms of https://docs.python.org/3/library/ipaddress.html#ipaddress.IPv4Address.is_global)
       - Among the IP addresses of the same globality status, prefer IPv4 addresses over IPv6
+
     If the default logic does not suit you, it is recommended to set the announced IP address manually.
     """
 
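
The selection logic described in the docstring amounts to sorting candidate addresses by a two-part key. The helper below is a hypothetical illustration built on the standard ipaddress module; it is not code from this commit.

from ipaddress import ip_address

def address_preference(addr: str):
    # lower tuples sort first: global addresses win, and within the same
    # globality status IPv4 (version 4) is preferred over IPv6
    ip = ip_address(addr)
    return (not ip.is_global, ip.version != 4)

candidates = ["192.168.1.10", "2606:4700:4700::1111", "8.8.8.8", "::1"]
print(min(candidates, key=address_preference))  # -> 8.8.8.8 (global and IPv4)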

+ 76 - 0
hivemind/utils/streaming.py

@@ -0,0 +1,76 @@
+"""
+Utilities for streaming serialized tensors over the network: split, combine and gather helpers
+"""
+
+from __future__ import annotations
+
+from typing import AsyncIterator, Callable, Iterable, Iterator, List, TypeVar
+
+import torch
+
+from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+STREAMING_CHUNK_SIZE_BYTES = 2**16
+
+
+def split_for_streaming(
+    serialized_tensor: runtime_pb2.Tensor,
+    chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES,
+) -> Iterator[runtime_pb2.Tensor]:
+    """Split serialized_tensor into multiple chunks for gRPC streaming"""
+    buffer = memoryview(serialized_tensor.buffer)
+    num_chunks = len(range(0, len(buffer), chunk_size_bytes))
+    yield runtime_pb2.Tensor(
+        compression=serialized_tensor.compression,
+        buffer=buffer[:chunk_size_bytes].tobytes(),
+        chunks=num_chunks,
+        size=serialized_tensor.size,
+        dtype=serialized_tensor.dtype,
+        requires_grad=serialized_tensor.requires_grad,
+    )
+    for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes):
+        yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes())
+
+
+def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor:
+    """Restore a result of split_into_chunks into a single serialized tensor"""
+    stream = iter(stream)
+    first_chunk = next(stream)
+    serialized_tensor = runtime_pb2.Tensor()
+    serialized_tensor.CopyFrom(first_chunk)
+    buffer_chunks = [first_chunk.buffer]
+    for tensor_part in stream:
+        buffer_chunks.append(tensor_part.buffer)
+    serialized_tensor.buffer = b"".join(buffer_chunks)
+    return serialized_tensor
+
+
+StreamMessage = TypeVar("StreamMessage")
+
+
+async def gather_from_streaming(
+    stream: AsyncIterator[StreamMessage],
+    key: Callable[[StreamMessage], Iterable[runtime_pb2.Tensor]],
+    deserializer: Callable[[runtime_pb2.Tensor], torch.Tensor],
+) -> List[torch.Tensor]:
+    """Async wrapper of combine_from_streaming allowing to work with arbitrary messages gathered from AsyncIterator"""
+
+    tensors = []
+    parts = []
+
+    async for msg in stream:
+        parts_stream = key(msg)
+        for part in parts_stream:
+            if part.dtype and parts:
+                tensors.append(deserializer(combine_from_streaming(parts)))
+                parts = []
+
+            parts.append(part)
+    if parts:
+        tensors.append(deserializer(combine_from_streaming(parts)))
+
+    return tensors
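
A quick round-trip check of the relocated helpers. The runtime_pb2.Tensor field names come from the code above; the concrete buffer and shape values are illustrative assumptions rather than output of hivemind's serializer.

from hivemind.proto import runtime_pb2
from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

# a hand-built "serialized tensor": 1 MiB of zero bytes described as 262144 float32 values
original = runtime_pb2.Tensor(buffer=b"\x00" * (1 << 20), size=[262144], dtype="float32", requires_grad=False)

chunks = list(split_for_streaming(original, chunk_size_bytes=2 ** 16))
assert len(chunks) == 16  # 1 MiB split into 64 KiB chunks
assert chunks[0].chunks == 16 and chunks[0].dtype == "float32"  # metadata is carried by the first chunk only

restored = combine_from_streaming(chunks)
assert restored.buffer == original.buffer and list(restored.size) == [262144]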

+ 0 - 44
tests/test_util_modules.py

@@ -330,50 +330,6 @@ def test_many_futures():
     p.join()
 
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_channel_cache():
-    hivemind.ChannelCache.MAXIMUM_CHANNELS = 3
-    hivemind.ChannelCache.EVICTION_PERIOD_SECONDS = 0.1
-
-    c1 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-    c2 = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
-    c3 = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
-    c3_again = hivemind.ChannelCache.get_stub("localhost:1338", DHTStub, aio=False)
-    c1_again = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-    c4 = hivemind.ChannelCache.get_stub("localhost:1339", DHTStub, aio=True)
-    c2_anew = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=True)
-    c1_yetagain = hivemind.ChannelCache.get_stub("localhost:1337", DHTStub, aio=False)
-
-    await asyncio.sleep(0.2)
-    c1_anew = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
-    c1_anew_again = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=DHTStub)
-    c1_otherstub = hivemind.ChannelCache.get_stub(target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub)
-    await asyncio.sleep(0.05)
-    c1_otherstub_again = hivemind.ChannelCache.get_stub(
-        target="localhost:1337", aio=False, stub_type=ConnectionHandlerStub
-    )
-    all_channels = [c1, c2, c3, c4, c3_again, c1_again, c2_anew, c1_yetagain, c1_anew, c1_anew_again, c1_otherstub]
-
-    assert all(isinstance(c, DHTStub) for c in all_channels[:-1])
-    assert isinstance(all_channels[-1], ConnectionHandlerStub)
-    assert "aio" in repr(c2.rpc_find)
-    assert "aio" not in repr(c1.rpc_find)
-
-    duplicates = {
-        (c1, c1_again),
-        (c1, c1_yetagain),
-        (c1_again, c1_yetagain),
-        (c3, c3_again),
-        (c1_anew, c1_anew_again),
-        (c1_otherstub, c1_otherstub_again),
-    }
-    for i in range(len(all_channels)):
-        for j in range(i + 1, len(all_channels)):
-            ci, cj = all_channels[i], all_channels[j]
-            assert (ci is cj) == ((ci, cj) in duplicates), (i, j)
-
-
 def test_serialize_tuple():
     test_pairs = (
         ((1, 2, 3), [1, 2, 3]),