Explorar o código

Clean up imports, remove unused utils (#486)

Max Ryabinin %!s(int64=3) %!d(string=hai) anos
pai
achega
bc2cccfdb0

+ 0 - 1
benchmarks/benchmark_dht.py

@@ -3,7 +3,6 @@ import asyncio
 import random
 import random
 import time
 import time
 import uuid
 import uuid
-from logging import shutdown
 from typing import Tuple
 from typing import Tuple
 
 
 import numpy as np
 import numpy as np

+ 1 - 1
hivemind/averaging/key_manager.py

@@ -7,7 +7,7 @@ import numpy as np
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.dht import DHT
 from hivemind.dht import DHT
 from hivemind.p2p import PeerID
 from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
+from hivemind.utils import DHTExpiration, get_logger
 
 
 GroupKey = str
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101

+ 2 - 2
hivemind/averaging/matchmaking.py

@@ -12,11 +12,11 @@ from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 from hivemind.averaging.control import StepControl
 from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
-from hivemind.dht import DHT, DHTID, DHTExpiration
+from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.proto import averaging_pb2
-from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
+from hivemind.utils import DHTExpiration, TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils.asyncio import anext, cancel_and_wait
 from hivemind.utils.asyncio import anext, cancel_and_wait
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)

+ 1 - 1
hivemind/dht/__init__.py

@@ -15,5 +15,5 @@ The code is organized as follows:
 
 
 from hivemind.dht.dht import DHT
 from hivemind.dht.dht import DHT
 from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
 from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, DHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTValue
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase

+ 1 - 1
hivemind/dht/routing.py

@@ -10,7 +10,7 @@ from itertools import chain
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 
 from hivemind.p2p import PeerID
 from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, MSGPackSerializer, get_dht_time
+from hivemind.utils import MSGPackSerializer, get_dht_time
 
 
 DHTKey = Subkey = DHTValue = Any
 DHTKey = Subkey = DHTValue = Any
 BinaryDHTID = BinaryDHTValue = bytes
 BinaryDHTID = BinaryDHTValue = bytes

+ 2 - 2
hivemind/moe/client/beam_search.py

@@ -4,7 +4,7 @@ from collections import deque
 from functools import partial
 from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 
-from hivemind.dht import DHT, DHTExpiration, DHTNode
+from hivemind.dht import DHT, DHTNode
 from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
 from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
 from hivemind.moe.expert_uid import (
 from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
     FLAT_EXPERT,
@@ -19,7 +19,7 @@ from hivemind.moe.expert_uid import (
     is_valid_uid,
     is_valid_uid,
 )
 )
 from hivemind.p2p import PeerID
 from hivemind.p2p import PeerID
-from hivemind.utils import MPFuture, ValueWithExpiration, get_dht_time, get_logger
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_dht_time, get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 

+ 2 - 2
hivemind/moe/server/dht_handler.py

@@ -2,7 +2,7 @@ import threading
 from functools import partial
 from functools import partial
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 
-from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
+from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
 from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
 from hivemind.moe.expert_uid import (
 from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
     FLAT_EXPERT,
@@ -16,7 +16,7 @@ from hivemind.moe.expert_uid import (
     split_uid,
     split_uid,
 )
 )
 from hivemind.p2p import PeerID
 from hivemind.p2p import PeerID
-from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, MPFuture, get_dht_time
+from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, DHTExpiration, MPFuture, get_dht_time
 
 
 
 
 class DHTHandlerThread(threading.Thread):
 class DHTHandlerThread(threading.Thread):

+ 1 - 1
hivemind/p2p/p2p_daemon.py

@@ -9,7 +9,7 @@ from contextlib import closing, suppress
 from dataclasses import dataclass
 from dataclasses import dataclass
 from datetime import datetime
 from datetime import datetime
 from importlib.resources import path
 from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 
 from google.protobuf.message import Message
 from google.protobuf.message import Message
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr

+ 1 - 1
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -5,7 +5,7 @@ Author: Kevin Mai-Husan Chia
 """
 """
 
 
 import hashlib
 import hashlib
-from typing import Any, Sequence, Tuple, Union
+from typing import Any, Sequence, Union
 
 
 import base58
 import base58
 import multihash
 import multihash

+ 1 - 1
hivemind/utils/__init__.py

@@ -3,7 +3,7 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.nested import *
-from hivemind.utils.networking import get_free_port, log_visible_maddrs
+from hivemind.utils.networking import log_visible_maddrs
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.streaming import combine_from_streaming, split_for_streaming
 from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

+ 0 - 7
hivemind/utils/asyncio.py

@@ -78,13 +78,6 @@ async def asingle(aiter: AsyncIterable[T]) -> T:
     return item
     return item
 
 
 
 
-async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
-    """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
-    async for item in aiter:
-        return item
-    return default
-
-
 async def await_cancelled(awaitable: Awaitable) -> bool:
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
     try:
         await awaitable
         await awaitable

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -9,7 +9,7 @@ import threading
 import uuid
 import uuid
 from contextlib import nullcontext
 from contextlib import nullcontext
 from enum import Enum, auto
 from enum import Enum, auto
-from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar
 from weakref import ref
 from weakref import ref
 
 
 import torch  # used for py3.7-compatible shared memory
 import torch  # used for py3.7-compatible shared memory

+ 0 - 18
hivemind/utils/networking.py

@@ -1,5 +1,3 @@
-import socket
-from contextlib import closing
 from ipaddress import ip_address
 from ipaddress import ip_address
 from typing import List, Sequence
 from typing import List, Sequence
 
 
@@ -12,22 +10,6 @@ LOCALHOST = "127.0.0.1"
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """
-    Finds a tcp port that can be occupied with a socket with *params and use *opt options.
-
-    :note: Using this function is discouraged since it often leads to a race condition
-           with the "Address is already in use" error if the code is run in parallel.
-    """
-    try:
-        with closing(socket.socket(*params)) as sock:
-            sock.bind(("", 0))
-            sock.setsockopt(*opt)
-            return sock.getsockname()[1]
-    except Exception as e:
-        raise e
-
-
 def choose_ip_address(
 def choose_ip_address(
     maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
     maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
 ) -> str:
 ) -> str:

+ 1 - 4
hivemind/utils/streaming.py

@@ -4,7 +4,7 @@ Utilities for streaming tensors
 
 
 from __future__ import annotations
 from __future__ import annotations
 
 
-from typing import Iterable, Iterator, TypeVar
+from typing import Iterable, Iterator
 
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
@@ -44,6 +44,3 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
         buffer_chunks.append(tensor_part.buffer)
         buffer_chunks.append(tensor_part.buffer)
     serialized_tensor.buffer = b"".join(buffer_chunks)
     serialized_tensor.buffer = b"".join(buffer_chunks)
     return serialized_tensor
     return serialized_tensor
-
-
-StreamMessage = TypeVar("StreamMessage")

+ 1 - 2
tests/test_allreduce.py

@@ -10,8 +10,7 @@ from hivemind import Quantile8BitQuantization, aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor
 from hivemind.compression import deserialize_torch_tensor
-from hivemind.p2p import P2P, StubBase
-from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.p2p import P2P
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked

+ 0 - 4
tests/test_allreduce_fault_tolerance.py

@@ -1,14 +1,10 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
-import asyncio
 from enum import Enum, auto
 from enum import Enum, auto
-from typing import AsyncIterator
 
 
 import pytest
 import pytest
-import torch
 
 
 import hivemind
 import hivemind
-from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.averager import *
 from hivemind.averaging.averager import *
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.load_balancing import load_balance_peers

+ 1 - 1
tests/test_dht.py

@@ -7,9 +7,9 @@ import pytest
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
 import hivemind
 import hivemind
-from hivemind.utils.networking import get_free_port
 
 
 from test_utils.dht_swarms import launch_dht_instances
 from test_utils.dht_swarms import launch_dht_instances
+from test_utils.networking import get_free_port
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio

+ 2 - 1
tests/test_p2p_daemon.py

@@ -13,9 +13,10 @@ from multiaddr import Multiaddr
 
 
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.proto import dht_pb2, test_pb2
 from hivemind.proto import dht_pb2, test_pb2
-from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.serializer import MSGPackSerializer
 
 
+from test_utils.networking import get_free_port
+
 
 
 def is_process_running(pid: int) -> bool:
 def is_process_running(pid: int) -> bool:
     return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
     return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0

+ 0 - 6
tests/test_util_modules.py

@@ -16,7 +16,6 @@ from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGP
 from hivemind.utils.asyncio import (
 from hivemind.utils.asyncio import (
     achain,
     achain,
     aenumerate,
     aenumerate,
-    afirst,
     aiter_with_timeout,
     aiter_with_timeout,
     amap_in_executor,
     amap_in_executor,
     anext,
     anext,
@@ -430,11 +429,6 @@ async def test_asyncio_utils():
     with pytest.raises(ValueError):
     with pytest.raises(ValueError):
         await asingle(as_aiter(1, 2, 3))
         await asingle(as_aiter(1, 2, 3))
 
 
-    assert await afirst(as_aiter(1)) == 1
-    assert await afirst(as_aiter()) is None
-    assert await afirst(as_aiter(), -1) == -1
-    assert await afirst(as_aiter(1, 2, 3)) == 1
-
     async def iterate_with_delays(delays):
     async def iterate_with_delays(delays):
         for i, delay in enumerate(delays):
         for i, delay in enumerate(delays):
             await asyncio.sleep(delay)
             await asyncio.sleep(delay)

+ 0 - 0
tests/test_utils/__init__.py


+ 18 - 0
tests/test_utils/networking.py

@@ -0,0 +1,18 @@
+import socket
+from contextlib import closing
+
+
+def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+    """
+    Finds a tcp port that can be occupied with a socket with *params and use *opt options.
+
+    :note: Using this function is discouraged since it often leads to a race condition
+           with the "Address is already in use" error if the code is run in parallel.
+    """
+    try:
+        with closing(socket.socket(*params)) as sock:
+            sock.bind(("", 0))
+            sock.setsockopt(*opt)
+            return sock.getsockname()[1]
+    except Exception as e:
+        raise e

+ 2 - 1
tests/test_utils/p2p_daemon.py

@@ -10,9 +10,10 @@ from typing import NamedTuple
 from multiaddr import Multiaddr, protocols
 from multiaddr import Multiaddr, protocols
 from pkg_resources import resource_filename
 from pkg_resources import resource_filename
 
 
-from hivemind import get_free_port
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 
 
+from test_utils.networking import get_free_port
+
 TIMEOUT_DURATION = 30  # seconds
 TIMEOUT_DURATION = 30  # seconds
 P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")
 P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")