
Clean up imports, remove unused utils (#486)

Max Ryabinin 3 years ago
parent commit bc2cccfdb0

+ 0 - 1
benchmarks/benchmark_dht.py

@@ -3,7 +3,6 @@ import asyncio
 import random
 import time
 import uuid
-from logging import shutdown
 from typing import Tuple
 
 import numpy as np

+ 1 - 1
hivemind/averaging/key_manager.py

@@ -7,7 +7,7 @@ import numpy as np
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.dht import DHT
 from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get_logger
+from hivemind.utils import DHTExpiration, get_logger
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101

+ 2 - 2
hivemind/averaging/matchmaking.py

@@ -12,11 +12,11 @@ from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
-from hivemind.dht import DHT, DHTID, DHTExpiration
+from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
-from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
+from hivemind.utils import DHTExpiration, TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils.asyncio import anext, cancel_and_wait
 
 logger = get_logger(__name__)

+ 1 - 1
hivemind/dht/__init__.py

@@ -15,5 +15,5 @@ The code is organized as follows:
 
 from hivemind.dht.dht import DHT
 from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, DHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTValue
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase

+ 1 - 1
hivemind/dht/routing.py

@@ -10,7 +10,7 @@ from itertools import chain
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.p2p import PeerID
-from hivemind.utils import DHTExpiration, MSGPackSerializer, get_dht_time
+from hivemind.utils import MSGPackSerializer, get_dht_time
 
 DHTKey = Subkey = DHTValue = Any
 BinaryDHTID = BinaryDHTValue = bytes

+ 2 - 2
hivemind/moe/client/beam_search.py

@@ -4,7 +4,7 @@ from collections import deque
 from functools import partial
 from typing import Deque, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
-from hivemind.dht import DHT, DHTExpiration, DHTNode
+from hivemind.dht import DHT, DHTNode
 from hivemind.moe.client.expert import RemoteExpert, batch_create_remote_experts, create_remote_experts
 from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
@@ -19,7 +19,7 @@ from hivemind.moe.expert_uid import (
     is_valid_uid,
 )
 from hivemind.p2p import PeerID
-from hivemind.utils import MPFuture, ValueWithExpiration, get_dht_time, get_logger
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 

+ 2 - 2
hivemind/moe/server/dht_handler.py

@@ -2,7 +2,7 @@ import threading
 from functools import partial
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
-from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
+from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
 from hivemind.moe.expert_uid import (
     FLAT_EXPERT,
@@ -16,7 +16,7 @@ from hivemind.moe.expert_uid import (
     split_uid,
 )
 from hivemind.p2p import PeerID
-from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, MPFuture, get_dht_time
+from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, DHTExpiration, MPFuture, get_dht_time
 
 
 class DHTHandlerThread(threading.Thread):

+ 1 - 1
hivemind/p2p/p2p_daemon.py

@@ -9,7 +9,7 @@ from contextlib import closing, suppress
 from dataclasses import dataclass
 from datetime import datetime
 from importlib.resources import path
-from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 from google.protobuf.message import Message
 from multiaddr import Multiaddr

+ 1 - 1
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -5,7 +5,7 @@ Author: Kevin Mai-Husan Chia
 """
 
 import hashlib
-from typing import Any, Sequence, Tuple, Union
+from typing import Any, Sequence, Union
 
 import base58
 import multihash

+ 1 - 1
hivemind/utils/__init__.py

@@ -3,7 +3,7 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
-from hivemind.utils.networking import get_free_port, log_visible_maddrs
+from hivemind.utils.networking import log_visible_maddrs
 from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.streaming import combine_from_streaming, split_for_streaming

+ 0 - 7
hivemind/utils/asyncio.py

@@ -78,13 +78,6 @@ async def asingle(aiter: AsyncIterable[T]) -> T:
     return item
 
 
-async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
-    """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
-    async for item in aiter:
-        return item
-    return default
-
-
 async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
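
Note: the removed afirst helper has no in-tree replacement. If code outside the library still depends on it, a minimal drop-in (reproducing the deleted implementation above, together with its imports) would be:

from typing import AsyncIterable, Optional, TypeVar

T = TypeVar("T")


async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
    # Return the first item yielded by aiter, or `default` if it yields nothing.
    async for item in aiter:
        return item
    return default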

+ 1 - 1
hivemind/utils/mpfuture.py

@@ -9,7 +9,7 @@ import threading
 import uuid
 from contextlib import nullcontext
 from enum import Enum, auto
-from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar
 from weakref import ref
 
 import torch  # used for py3.7-compatible shared memory

+ 0 - 18
hivemind/utils/networking.py

@@ -1,5 +1,3 @@
-import socket
-from contextlib import closing
 from ipaddress import ip_address
 from typing import List, Sequence
 
@@ -12,22 +10,6 @@ LOCALHOST = "127.0.0.1"
 logger = get_logger(__name__)
 
 
-def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """
-    Finds a tcp port that can be occupied with a socket with *params and use *opt options.
-
-    :note: Using this function is discouraged since it often leads to a race condition
-           with the "Address is already in use" error if the code is run in parallel.
-    """
-    try:
-        with closing(socket.socket(*params)) as sock:
-            sock.bind(("", 0))
-            sock.setsockopt(*opt)
-            return sock.getsockname()[1]
-    except Exception as e:
-        raise e
-
-
 def choose_ip_address(
     maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6")
 ) -> str:

+ 1 - 4
hivemind/utils/streaming.py

@@ -4,7 +4,7 @@ Utilities for streaming tensors
 
 from __future__ import annotations
 
-from typing import Iterable, Iterator, TypeVar
+from typing import Iterable, Iterator
 
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
@@ -44,6 +44,3 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
         buffer_chunks.append(tensor_part.buffer)
     serialized_tensor.buffer = b"".join(buffer_chunks)
     return serialized_tensor
-
-
-StreamMessage = TypeVar("StreamMessage")

+ 1 - 2
tests/test_allreduce.py

@@ -10,8 +10,7 @@ from hivemind import Quantile8BitQuantization, aenumerate
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor
-from hivemind.p2p import P2P, StubBase
-from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.p2p import P2P
 
 
 @pytest.mark.forked

+ 0 - 4
tests/test_allreduce_fault_tolerance.py

@@ -1,14 +1,10 @@
 from __future__ import annotations
 
-import asyncio
 from enum import Enum, auto
-from typing import AsyncIterator
 
 import pytest
-import torch
 
 import hivemind
-from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.averager import *
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers

+ 1 - 1
tests/test_dht.py

@@ -7,9 +7,9 @@ import pytest
 from multiaddr import Multiaddr
 
 import hivemind
-from hivemind.utils.networking import get_free_port
 
 from test_utils.dht_swarms import launch_dht_instances
+from test_utils.networking import get_free_port
 
 
 @pytest.mark.asyncio

+ 2 - 1
tests/test_p2p_daemon.py

@@ -13,9 +13,10 @@ from multiaddr import Multiaddr
 
 from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.proto import dht_pb2, test_pb2
-from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 
+from test_utils.networking import get_free_port
+
 
 def is_process_running(pid: int) -> bool:
     return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0

+ 0 - 6
tests/test_util_modules.py

@@ -16,7 +16,6 @@ from hivemind.utils import BatchTensorDescriptor, DHTExpiration, HeapEntry, MSGP
 from hivemind.utils.asyncio import (
     achain,
     aenumerate,
-    afirst,
     aiter_with_timeout,
     amap_in_executor,
     anext,
@@ -430,11 +429,6 @@ async def test_asyncio_utils():
     with pytest.raises(ValueError):
         await asingle(as_aiter(1, 2, 3))
 
-    assert await afirst(as_aiter(1)) == 1
-    assert await afirst(as_aiter()) is None
-    assert await afirst(as_aiter(), -1) == -1
-    assert await afirst(as_aiter(1, 2, 3)) == 1
-
     async def iterate_with_delays(delays):
         for i, delay in enumerate(delays):
             await asyncio.sleep(delay)

+ 0 - 0
tests/test_utils/__init__.py (new empty file, making test_utils an importable package)


+ 18 - 0
tests/test_utils/networking.py

@@ -0,0 +1,18 @@
+import socket
+from contextlib import closing
+
+
+def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+    """
+    Finds a tcp port that can be occupied with a socket with *params and use *opt options.
+
+    :note: Using this function is discouraged since it often leads to a race condition
+           with the "Address is already in use" error if the code is run in parallel.
+    """
+    try:
+        with closing(socket.socket(*params)) as sock:
+            sock.bind(("", 0))
+            sock.setsockopt(*opt)
+            return sock.getsockname()[1]
+    except Exception as e:
+        raise e
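
A minimal usage sketch for the relocated helper (a hypothetical test snippet, not part of this commit). As the docstring warns, the returned port may be claimed by another process before it is reused, so bind it as soon as possible:

import socket

from test_utils.networking import get_free_port

port = get_free_port()
# Bind immediately; the reservation is best-effort and racy under parallel test runs.
with socket.create_server(("127.0.0.1", port)) as server:
    assert server.getsockname()[1] == port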

+ 2 - 1
tests/test_utils/p2p_daemon.py

@@ -10,9 +10,10 @@ from typing import NamedTuple
 from multiaddr import Multiaddr, protocols
 from pkg_resources import resource_filename
 
-from hivemind import get_free_port
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 
+from test_utils.networking import get_free_port
+
 TIMEOUT_DURATION = 30  # seconds
 P2PD_PATH = resource_filename("hivemind", "hivemind_cli/p2pd")