Browse Source

new fixes

Michael Diskin 3 years ago
parent
commit
fe885ecc57

+ 1 - 1
hivemind/averaging/load_balancing.py

@@ -10,7 +10,7 @@ logger = get_logger(__name__)
 LOAD_BALANCING_LP_DECIMALS = 9
 LOAD_BALANCING_LP_DECIMALS = 9
 
 
 
 
-def load_balance_peers(vector_size, bandwidths: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
+def load_balance_peers(vector_size: int, bandwidths: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
     """
     """
     Find an optimal partitioning of weights for butterfly all-reduce given peer bandwidths.
     Find an optimal partitioning of weights for butterfly all-reduce given peer bandwidths.
     :param vector_size: total size of the averaged vector (in elements, not bytes)
     :param vector_size: total size of the averaged vector (in elements, not bytes)

+ 2 - 2
hivemind/averaging/partition.py

@@ -140,7 +140,7 @@ class TensorPartContainer:
         self._outputs_consumed = True
         self._outputs_consumed = True
         peer_index = num_parts_processed = 0
         peer_index = num_parts_processed = 0
         for tensor_index in range(len(self.local_tensors)):
         for tensor_index in range(len(self.local_tensors)):
-            tensor_parts: List[torch.Tensor] = []
+            tensor_parts = []
             while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
             while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
                 if num_parts_processed >= self.num_parts_by_peer[peer_index]:
                 if num_parts_processed >= self.num_parts_by_peer[peer_index]:
                     num_parts_processed = 0
                     num_parts_processed = 0
@@ -187,7 +187,7 @@ class TensorPartReducer:
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
-        self.accumulator: Optional[torch.Tensor] = None  # contains the sum of current tensor part from group peers
+        self.accumulator = None  # this will contain the sum of current tensor part from group peers
         self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
         self.current_part_future = asyncio.Future()
         self.finished = asyncio.Event()
         self.finished = asyncio.Event()

+ 3 - 3
hivemind/dht/traverse.py

@@ -75,9 +75,9 @@ async def traverse_dht(
     num_workers: int,
     num_workers: int,
     queries_per_call: int,
     queries_per_call: int,
     get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[Tuple[DHTID], bool]]]],
     get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[Tuple[DHTID], bool]]]],
-    visited_nodes: Dict[DHTID, Set[DHTID]],
     found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
     found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
     await_all_tasks: bool = True,
     await_all_tasks: bool = True,
+    visited_nodes: Optional[Dict[DHTID, Set[DHTID]]] = (),
 ) -> Tuple[Dict[DHTID, List[DHTID]], Dict[DHTID, Set[DHTID]]]:
 ) -> Tuple[Dict[DHTID, List[DHTID]], Dict[DHTID, Set[DHTID]]]:
     """
     """
     Search the DHT for nearest neighbors to :queries: (based on DHTID.xor_distance). Use get_neighbors to request peers.
     Search the DHT for nearest neighbors to :queries: (based on DHTID.xor_distance). Use get_neighbors to request peers.
@@ -109,7 +109,6 @@ async def traverse_dht(
         nearest neighbors and finishes the remaining tasks (callbacks and queries to known-but-unvisited nodes)
         nearest neighbors and finishes the remaining tasks (callbacks and queries to known-but-unvisited nodes)
 
 
     :param visited_nodes: for each query, do not call get_neighbors on these nodes, nor return them among nearest.
     :param visited_nodes: for each query, do not call get_neighbors on these nodes, nor return them among nearest.
-
     :note: the source code of this function can get tricky to read. Take a look at `simple_traverse_dht` function
     :note: the source code of this function can get tricky to read. Take a look at `simple_traverse_dht` function
         for reference. That function implements a special case of traverse_dht with a single query and one worker.
         for reference. That function implements a special case of traverse_dht with a single query and one worker.
 
 
@@ -118,12 +117,13 @@ async def traverse_dht(
         visited nodes: { query -> a set of all nodes that received requests for a given query }
         visited nodes: { query -> a set of all nodes that received requests for a given query }
     """
     """
     if len(queries) == 0:
     if len(queries) == 0:
-        return {}, visited_nodes
+        return {}, dict(visited_nodes or {})
 
 
     unfinished_queries = set(queries)  # all queries that haven't triggered finish_search yet
     unfinished_queries = set(queries)  # all queries that haven't triggered finish_search yet
     candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: unvisited nodes, ordered nearest-to-farthest
     candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: unvisited nodes, ordered nearest-to-farthest
     nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: top-k nearest nodes, farthest-to-nearest
     nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: top-k nearest nodes, farthest-to-nearest
     known_nodes: Dict[DHTID, Set[DHTID]] = {}  # all nodes ever added to the heap (for deduplication)
     known_nodes: Dict[DHTID, Set[DHTID]] = {}  # all nodes ever added to the heap (for deduplication)
+    visited_nodes: Dict[DHTID, Set[DHTID]] = dict(visited_nodes or {})  # nodes that were chosen for get_neighbors call
     pending_tasks = set()  # all active tasks (get_neighbors and found_callback)
     pending_tasks = set()  # all active tasks (get_neighbors and found_callback)
     active_workers = Counter({q: 0 for q in queries})  # count workers that search for this query
     active_workers = Counter({q: 0 for q in queries})  # count workers that search for this query
 
 

+ 2 - 2
hivemind/dht/validation.py

@@ -1,6 +1,6 @@
 import dataclasses
 import dataclasses
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
-from typing import Iterable, List
+from typing import Iterable
 
 
 
 
 @dataclasses.dataclass(init=True, repr=True, frozen=True)
 @dataclasses.dataclass(init=True, repr=True, frozen=True)
@@ -91,7 +91,7 @@ class RecordValidatorBase(ABC):
 
 
 class CompositeValidator(RecordValidatorBase):
 class CompositeValidator(RecordValidatorBase):
     def __init__(self, validators: Iterable[RecordValidatorBase] = ()):
     def __init__(self, validators: Iterable[RecordValidatorBase] = ()):
-        self._validators: List[RecordValidatorBase] = []
+        self._validators = []
         self.extend(validators)
         self.extend(validators)
 
 
     def extend(self, validators: Iterable[RecordValidatorBase]) -> None:
     def extend(self, validators: Iterable[RecordValidatorBase]) -> None:

+ 1 - 1
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -57,7 +57,7 @@ class DaemonConnector:
         else:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
 
 
-    async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+    async def open_persistent_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
         """
         """
         Open connection to daemon and upgrade it to a persistent one
         Open connection to daemon and upgrade it to a persistent one
         """
         """

+ 5 - 5
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -55,13 +55,13 @@ class PeerID:
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return f"<libp2p.peer.id.ID ({self.to_base58()})>"
         return f"<libp2p.peer.id.ID ({self.to_base58()})>"
 
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.to_base58()
         return self.to_base58()
 
 
-    def pretty(self):
+    def pretty(self) -> str:
         return self.to_base58()
         return self.to_base58()
 
 
-    def to_string(self):
+    def to_string(self) -> str:
         return self.to_base58()
         return self.to_base58()
 
 
     def __eq__(self, other: object) -> bool:
     def __eq__(self, other: object) -> bool:
@@ -128,10 +128,10 @@ class PeerInfo:
         addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
         addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
         return PeerInfo(peer_id, addrs)
         return PeerInfo(peer_id, addrs)
 
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})"
         return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})"
 
 
 
 

+ 1 - 1
hivemind/utils/asyncio.py

@@ -111,7 +111,7 @@ async def amap_in_executor(
 ) -> AsyncIterator[T]:
 ) -> AsyncIterator[T]:
     """iterate from an async iterable in a background thread, yield results to async iterable"""
     """iterate from an async iterable in a background thread, yield results to async iterable"""
     loop = asyncio.get_event_loop()
     loop = asyncio.get_event_loop()
-    queue: asyncio.Queue[Optional[Awaitable[T]]] = asyncio.Queue(max_prefetch)
+    queue = asyncio.Queue(max_prefetch)
 
 
     async def _put_items():
     async def _put_items():
         try:
         try:

+ 1 - 1
hivemind/utils/auth.py

@@ -63,7 +63,7 @@ class TokenAuthorizerBase(AuthorizerBase):
         self._local_access_token = None
         self._local_access_token = None
         self._refresh_lock = asyncio.Lock()
         self._refresh_lock = asyncio.Lock()
 
 
-        self._recent_nonces: TimedStorage[bytes, None] = TimedStorage()
+        self._recent_nonces = TimedStorage()
 
 
     @abstractmethod
     @abstractmethod
     async def get_token(self) -> AccessToken:
     async def get_token(self) -> AccessToken:

+ 8 - 7
hivemind/utils/crypto.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import base64
 import base64
 import threading
 import threading
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
+from typing import Any, Dict
 
 
 from cryptography import exceptions
 from cryptography import exceptions
 from cryptography.hazmat.primitives import hashes, serialization
 from cryptography.hazmat.primitives import hashes, serialization
@@ -39,7 +40,7 @@ _RSA_HASH_ALGORITHM = hashes.SHA256()
 
 
 
 
 class RSAPrivateKey(PrivateKey):
 class RSAPrivateKey(PrivateKey):
-    def __init__(self):
+    def __init__(self) -> None:
         self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
         self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
 
 
     _process_wide_key = None
     _process_wide_key = None
@@ -60,7 +61,7 @@ class RSAPrivateKey(PrivateKey):
     def get_public_key(self) -> RSAPublicKey:
     def get_public_key(self) -> RSAPublicKey:
         return RSAPublicKey(self._private_key.public_key())
         return RSAPublicKey(self._private_key.public_key())
 
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         state = self.__dict__.copy()
         # Serializes the private key to make the class instances picklable
         # Serializes the private key to make the class instances picklable
         state["_private_key"] = self._private_key.private_bytes(
         state["_private_key"] = self._private_key.private_bytes(
@@ -70,13 +71,13 @@ class RSAPrivateKey(PrivateKey):
         )
         )
         return state
         return state
 
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self.__dict__.update(state)
         self.__dict__.update(state)
         self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
         self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
 
 
 
 
 class RSAPublicKey(PublicKey):
 class RSAPublicKey(PublicKey):
-    def __init__(self, public_key: rsa.RSAPublicKey):
+    def __init__(self, public_key: rsa.RSAPublicKey) -> None:
         self._public_key = public_key
         self._public_key = public_key
 
 
     def verify(self, data: bytes, signature: bytes) -> bool:
     def verify(self, data: bytes, signature: bytes) -> bool:
@@ -97,7 +98,7 @@ class RSAPublicKey(PublicKey):
 
 
     @classmethod
     @classmethod
     def from_bytes(cls, key: bytes) -> RSAPublicKey:
     def from_bytes(cls, key: bytes) -> RSAPublicKey:
-        loaded_key = serialization.load_ssh_public_key(key)
-        if not isinstance(loaded_key, rsa.RSAPublicKey):
+        deserialized_key = serialization.load_ssh_public_key(key)
+        if not isinstance(deserialized_key, rsa.RSAPublicKey):
             raise ValueError(f"Expected an RSA public key, got {str(key)}")
             raise ValueError(f"Expected an RSA public key, got {str(key)}")
-        return cls(loaded_key)
+        return cls(deserialized_key)

+ 3 - 3
hivemind/utils/grpc.py

@@ -54,7 +54,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
     _lock: threading.RLock = threading.RLock()
     _lock: threading.RLock = threading.RLock()
     _update_eviction_evt: threading.Event = threading.Event()
     _update_eviction_evt: threading.Event = threading.Event()
 
 
-    def __init__(self, _created_as_singleton=False):
+    def __init__(self, _created_as_singleton=False) -> None:
         assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
         assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
         super().__init__(maxsize=self.MAXIMUM_CHANNELS)
         super().__init__(maxsize=self.MAXIMUM_CHANNELS)
         self._is_active = True
         self._is_active = True
@@ -76,7 +76,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
     def get_stub(
     def get_stub(
         cls,
         cls,
         target: Endpoint,
         target: Endpoint,
-        stub_type: Type[Stub],
+        stub_type: type,
         *,
         *,
         aio: bool,
         aio: bool,
         options: Tuple[Tuple[str, Any]] = (),
         options: Tuple[Tuple[str, Any]] = (),
@@ -146,7 +146,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
                 target, credentials=channel_credentials, options=options, compression=compression
                 target, credentials=channel_credentials, options=options, compression=compression
             )
             )
 
 
-    def _evict_stale_channels_in_background(self):
+    def _evict_stale_channels_in_background(self) -> None:
         while self._is_active:
         while self._is_active:
             now = get_dht_time()
             now = get_dht_time()
             time_to_wait = max(0.0, self._nearest_expiration_time - now)
             time_to_wait = max(0.0, self._nearest_expiration_time - now)

+ 1 - 1
hivemind/utils/limits.py

@@ -3,7 +3,7 @@ from hivemind.utils.logging import get_logger
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
+def increase_file_limit(new_soft: int = 2 ** 15, new_hard: int = 2 ** 15) -> None:
     """Increase the maximum number of open files. On Linux, this allows spawning more processes/threads."""
     """Increase the maximum number of open files. On Linux, this allows spawning more processes/threads."""
     try:
     try:
         import resource  # local import to avoid ImportError for Windows users
         import resource  # local import to avoid ImportError for Windows users

+ 1 - 1
hivemind/utils/logging.py

@@ -83,7 +83,7 @@ class CustomFormatter(logging.Formatter):
         return super().format(record)
         return super().format(record)
 
 
 
 
-def _initialize_if_necessary():
+def _initialize_if_necessary() -> None:
     global _current_mode, _default_handler
     global _current_mode, _default_handler
 
 
     with _init_lock:
     with _init_lock:

+ 15 - 16
hivemind/utils/mpfuture.py

@@ -3,7 +3,6 @@ from __future__ import annotations
 import asyncio
 import asyncio
 import concurrent.futures._base as base
 import concurrent.futures._base as base
 import multiprocessing as mp
 import multiprocessing as mp
-import multiprocessing.connection
 import os
 import os
 import threading
 import threading
 import uuid
 import uuid
@@ -131,13 +130,13 @@ class MPFuture(base.Future, Generic[ResultType]):
         if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
         if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
             self._set_event_threadsafe()
             self._set_event_threadsafe()
 
 
-    def _set_event_threadsafe(self):
+    def _set_event_threadsafe(self) -> None:
         try:
         try:
             running_loop = asyncio.get_running_loop()
             running_loop = asyncio.get_running_loop()
         except RuntimeError:
         except RuntimeError:
             running_loop = None
             running_loop = None
 
 
-        async def _event_setter():
+        async def _event_setter() -> None:
             self._aio_event.set()
             self._aio_event.set()
 
 
         if self._loop.is_closed():
         if self._loop.is_closed():
@@ -150,7 +149,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             self._loop.run_until_complete(_event_setter())
             self._loop.run_until_complete(_event_setter())
 
 
     @classmethod
     @classmethod
-    def _initialize_mpfuture_backend(cls):
+    def _initialize_mpfuture_backend(cls) -> None:
         pid = os.getpid()
         pid = os.getpid()
         logger.debug(f"Initializing MPFuture backend for pid {pid}")
         logger.debug(f"Initializing MPFuture backend for pid {pid}")
 
 
@@ -162,7 +161,7 @@ class MPFuture(base.Future, Generic[ResultType]):
         cls._pipe_waiter_thread.start()
         cls._pipe_waiter_thread.start()
 
 
     @staticmethod
     @staticmethod
-    def reset_backend():
+    def reset_backend() -> None:
         """Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
         """Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
         MPFuture._active_pid = None
         MPFuture._active_pid = None
         MPFuture._initialization_lock = mp.Lock()
         MPFuture._initialization_lock = mp.Lock()
@@ -200,7 +199,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             except Exception as e:
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
 
-    def _send_update(self, update_type: UpdateType, payload: Any = None):
+    def _send_update(self, update_type: UpdateType, payload: Any = None) -> None:
         """This method sends result, exception or cancel to the MPFuture origin."""
         """This method sends result, exception or cancel to the MPFuture origin."""
         try:
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
             with MPFuture._update_lock if self._use_lock else nullcontext():
@@ -208,7 +207,7 @@ class MPFuture(base.Future, Generic[ResultType]):
         except (ConnectionError, BrokenPipeError, EOFError, OSError) as e:
         except (ConnectionError, BrokenPipeError, EOFError, OSError) as e:
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e})", exc_info=True)
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e})", exc_info=True)
 
 
-    def set_result(self, result: ResultType):
+    def set_result(self, result: ResultType) -> None:
         if os.getpid() == self._origin_pid:
         if os.getpid() == self._origin_pid:
             super().set_result(result)
             super().set_result(result)
             MPFuture._active_futures.pop(self._uid, None)
             MPFuture._active_futures.pop(self._uid, None)
@@ -218,7 +217,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             self._state_cache[self._state], self._result = base.FINISHED, result
             self._state_cache[self._state], self._result = base.FINISHED, result
             self._send_update(UpdateType.RESULT, result)
             self._send_update(UpdateType.RESULT, result)
 
 
-    def set_exception(self, exception: Optional[BaseException]):
+    def set_exception(self, exception: Optional[BaseException]) -> None:
         if os.getpid() == self._origin_pid:
         if os.getpid() == self._origin_pid:
             super().set_exception(exception)
             super().set_exception(exception)
             MPFuture._active_futures.pop(self._uid, None)
             MPFuture._active_futures.pop(self._uid, None)
@@ -239,7 +238,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             self._send_update(UpdateType.CANCEL)
             self._send_update(UpdateType.CANCEL)
             return True
             return True
 
 
-    def set_running_or_notify_cancel(self):
+    def set_running_or_notify_cancel(self) -> bool:
         if self._state == base.PENDING:
         if self._state == base.PENDING:
             self._state = base.RUNNING
             self._state = base.RUNNING
             return True
             return True
@@ -274,18 +273,18 @@ class MPFuture(base.Future, Generic[ResultType]):
     def done(self) -> bool:
     def done(self) -> bool:
         return self._state in TERMINAL_STATES
         return self._state in TERMINAL_STATES
 
 
-    def running(self):
+    def running(self) -> bool:
         return self._state == base.RUNNING
         return self._state == base.RUNNING
 
 
-    def cancelled(self):
+    def cancelled(self) -> bool:
         return self._state == base.CANCELLED
         return self._state == base.CANCELLED
 
 
-    def add_done_callback(self, callback: Callable[[MPFuture], None]):
+    def add_done_callback(self, callback: Callable[[MPFuture], None]) -> None:
         if os.getpid() != self._origin_pid:
         if os.getpid() != self._origin_pid:
             raise RuntimeError("Only the process that created MPFuture can set callbacks")
             raise RuntimeError("Only the process that created MPFuture can set callbacks")
         return super().add_done_callback(callback)
         return super().add_done_callback(callback)
 
 
-    def __await__(self):
+    def __await__(self) -> Any:
         if not self._aio_event:
         if not self._aio_event:
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
         yield from self._aio_event.wait().__await__()
         yield from self._aio_event.wait().__await__()
@@ -294,13 +293,13 @@ class MPFuture(base.Future, Generic[ResultType]):
         except base.CancelledError:
         except base.CancelledError:
             raise asyncio.CancelledError()
             raise asyncio.CancelledError()
 
 
-    def __del__(self):
+    def __del__(self) -> None:
         if getattr(self, "_origin_pid", None) == os.getpid():
         if getattr(self, "_origin_pid", None) == os.getpid():
             MPFuture._active_futures.pop(self._uid, None)
             MPFuture._active_futures.pop(self._uid, None)
         if getattr(self, "_aio_event", None):
         if getattr(self, "_aio_event", None):
             self._aio_event.set()
             self._aio_event.set()
 
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         return dict(
         return dict(
             _sender_pipe=self._sender_pipe,
             _sender_pipe=self._sender_pipe,
             _shared_state_code=self._shared_state_code,
             _shared_state_code=self._shared_state_code,
@@ -311,7 +310,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             _exception=self._exception,
             _exception=self._exception,
         )
         )
 
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self._sender_pipe = state["_sender_pipe"]
         self._sender_pipe = state["_sender_pipe"]
         self._shared_state_code = state["_shared_state_code"]
         self._shared_state_code = state["_shared_state_code"]
         self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]

+ 6 - 5
hivemind/utils/nested.py

@@ -1,7 +1,8 @@
 """ utility functions that help you process nested dicts, tuples, lists and namedtuples """
 """ utility functions that help you process nested dicts, tuples, lists and namedtuples """
+from typing import Any
 
 
 
 
-def nested_compare(t, u):
+def nested_compare(t: Any, u: Any) -> bool:
     """
     """
     Return whether nested structure of t1 and t2 matches.
     Return whether nested structure of t1 and t2 matches.
     """
     """
@@ -29,7 +30,7 @@ def nested_compare(t, u):
         return True
         return True
 
 
 
 
-def nested_flatten(t):
+def nested_flatten(t: Any) -> Any:
     """
     """
     Turn nested list/tuple/dict into a flat iterator.
     Turn nested list/tuple/dict into a flat iterator.
     """
     """
@@ -43,7 +44,7 @@ def nested_flatten(t):
         yield t
         yield t
 
 
 
 
-def nested_pack(flat, structure):
+def nested_pack(flat: Any, structure: Any) -> Any:
     """
     """
     Restore nested structure from flattened state
     Restore nested structure from flattened state
     :param flat: result of nested_flatten
     :param flat: result of nested_flatten
@@ -53,7 +54,7 @@ def nested_pack(flat, structure):
     return _nested_pack(iter(flat), structure)
     return _nested_pack(iter(flat), structure)
 
 
 
 
-def _nested_pack(flat_iter, structure):
+def _nested_pack(flat_iter: Any, structure: Any) -> Any:
     if is_namedtuple(structure):
     if is_namedtuple(structure):
         return type(structure)(*[_nested_pack(flat_iter, x) for x in structure])
         return type(structure)(*[_nested_pack(flat_iter, x) for x in structure])
     elif isinstance(structure, (list, tuple)):
     elif isinstance(structure, (list, tuple)):
@@ -64,7 +65,7 @@ def _nested_pack(flat_iter, structure):
         return next(flat_iter)
         return next(flat_iter)
 
 
 
 
-def is_namedtuple(x):
+def is_namedtuple(x) -> bool:
     """Checks if x is a namedtuple instance. Taken from https://stackoverflow.com/a/2166841 ."""
     """Checks if x is a namedtuple instance. Taken from https://stackoverflow.com/a/2166841 ."""
     t = type(x)
     t = type(x)
     b = t.__bases__
     b = t.__bases__

+ 5 - 2
hivemind/utils/networking.py

@@ -5,7 +5,8 @@ from typing import Optional, Sequence
 
 
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
-Hostname, Port = str, int  # flavour types
+Hostname = str
+Port = int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = "127.0.0.1"
 LOCALHOST = "127.0.0.1"
 
 
@@ -30,7 +31,9 @@ def strip_port(endpoint: Endpoint) -> Hostname:
     return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
     return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
 
 
 
 
-def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+def get_free_port(
+    params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+) -> Port:
     """
     """
     Finds a tcp port that can be occupied with a socket with *params and use *opt options.
     Finds a tcp port that can be occupied with a socket with *params and use *opt options.
 
 

+ 2 - 2
hivemind/utils/performance_ema.py

@@ -37,7 +37,7 @@ class PerformanceEMA:
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
         return self.samples_per_second
         return self.samples_per_second
 
 
-    def reset_timer(self):
+    def reset_timer(self) -> None:
         """Reset the time since the last update so that the next task performance is counted from current time"""
         """Reset the time since the last update so that the next task performance is counted from current time"""
         self.timestamp = time.perf_counter()
         self.timestamp = time.perf_counter()
 
 
@@ -51,7 +51,7 @@ class PerformanceEMA:
             self.paused = was_paused
             self.paused = was_paused
             self.reset_timer()
             self.reset_timer()
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
 
 
     @contextmanager
     @contextmanager

+ 4 - 4
hivemind/utils/serializer.py

@@ -1,6 +1,6 @@
 """ A unified interface for several common serialization methods """
 """ A unified interface for several common serialization methods """
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
-from typing import Any, Dict
+from typing import Any, Callable, Dict
 
 
 import msgpack
 import msgpack
 
 
@@ -27,10 +27,10 @@ class MSGPackSerializer(SerializerBase):
     _TUPLE_EXT_TYPE_CODE = 0x40
     _TUPLE_EXT_TYPE_CODE = 0x40
 
 
     @classmethod
     @classmethod
-    def ext_serializable(cls, type_code: int):
+    def ext_serializable(cls, type_code: int) -> Callable[[type], type]:
         assert isinstance(type_code, int), "Please specify a (unique) int type code"
         assert isinstance(type_code, int), "Please specify a (unique) int type code"
 
 
-        def wrap(wrapped_type: type):
+        def wrap(wrapped_type: type) -> type:
             assert callable(getattr(wrapped_type, "packb", None)) and callable(
             assert callable(getattr(wrapped_type, "packb", None)) and callable(
                 getattr(wrapped_type, "unpackb", None)
                 getattr(wrapped_type, "unpackb", None)
             ), f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
             ), f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
@@ -42,7 +42,7 @@ class MSGPackSerializer(SerializerBase):
         return wrap
         return wrap
 
 
     @classmethod
     @classmethod
-    def _encode_ext_types(cls, obj):
+    def _encode_ext_types(cls, obj: object):
         type_code = cls._ext_types.get(type(obj))
         type_code = cls._ext_types.get(type(obj))
         if type_code is not None:
         if type_code is not None:
             return msgpack.ExtType(type_code, obj.packb())
             return msgpack.ExtType(type_code, obj.packb())

+ 12 - 12
hivemind/utils/timed_storage.py

@@ -5,7 +5,7 @@ import heapq
 import time
 import time
 from contextlib import contextmanager
 from contextlib import contextmanager
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import Dict, Generic, Iterator, List, Optional, Tuple, TypeVar
+from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
 
 
 KeyType = TypeVar("KeyType")
 KeyType = TypeVar("KeyType")
 ValueType = TypeVar("ValueType")
 ValueType = TypeVar("ValueType")
@@ -20,10 +20,10 @@ class ValueWithExpiration(Generic[ValueType]):
     value: ValueType
     value: ValueType
     expiration_time: DHTExpiration
     expiration_time: DHTExpiration
 
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Union[ValueType, DHTExpiration]]:
         return iter((self.value, self.expiration_time))
         return iter((self.value, self.expiration_time))
 
 
-    def __getitem__(self, item):
+    def __getitem__(self, item: Union[int, str]) -> Union[ValueType, DHTExpiration]:
         if item == 0:
         if item == 0:
             return self.value
             return self.value
         elif item == 1:
         elif item == 1:
@@ -31,7 +31,7 @@ class ValueWithExpiration(Generic[ValueType]):
         else:
         else:
             return getattr(self, item)
             return getattr(self, item)
 
 
-    def __eq__(self, item):
+    def __eq__(self, item: Any) -> bool:
         if isinstance(item, ValueWithExpiration):
         if isinstance(item, ValueWithExpiration):
             return self.value == item.value and self.expiration_time == item.expiration_time
             return self.value == item.value and self.expiration_time == item.expiration_time
         elif isinstance(item, tuple):
         elif isinstance(item, tuple):
@@ -51,13 +51,13 @@ class TimedStorage(Generic[KeyType, ValueType]):
 
 
     frozen = False  # can be set to True. If true, do not remove outdated elements
     frozen = False  # can be set to True. If true, do not remove outdated elements
 
 
-    def __init__(self, maxsize: Optional[int] = None):
+    def __init__(self, maxsize: Optional[int] = None) -> None:
         self.maxsize = maxsize or float("inf")
         self.maxsize = maxsize or float("inf")
         self.data: Dict[KeyType, ValueWithExpiration[ValueType]] = dict()
         self.data: Dict[KeyType, ValueWithExpiration[ValueType]] = dict()
         self.expiration_heap: List[HeapEntry[KeyType]] = []
         self.expiration_heap: List[HeapEntry[KeyType]] = []
         self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()
         self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()
 
 
-    def _remove_outdated(self):
+    def _remove_outdated(self) -> None:
         while (
         while (
             not self.frozen
             not self.frozen
             and self.expiration_heap
             and self.expiration_heap
@@ -108,28 +108,28 @@ class TimedStorage(Generic[KeyType, ValueType]):
             return top_key, self.data[top_key]
             return top_key, self.data[top_key]
         return None, None
         return None, None
 
 
-    def clear(self):
+    def clear(self) -> None:
         self.data.clear()
         self.data.clear()
         self.key_to_heap.clear()
         self.key_to_heap.clear()
         self.expiration_heap.clear()
         self.expiration_heap.clear()
 
 
-    def __contains__(self, key: KeyType):
+    def __contains__(self, key: KeyType) -> bool:
         self._remove_outdated()
         self._remove_outdated()
         return key in self.data
         return key in self.data
 
 
-    def __len__(self):
+    def __len__(self) -> int:
         self._remove_outdated()
         self._remove_outdated()
         return len(self.data)
         return len(self.data)
 
 
-    def __delitem__(self, key: KeyType):
+    def __delitem__(self, key: KeyType) -> None:
         if key in self.key_to_heap:
         if key in self.key_to_heap:
             del self.data[key], self.key_to_heap[key]
             del self.data[key], self.key_to_heap[key]
         # note: key may still be in self.expiration_heap, but it will not be used and eventually ._remove_outdated()
         # note: key may still be in self.expiration_heap, but it will not be used and eventually ._remove_outdated()
 
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.data)
         return bool(self.data)
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.data})"
         return f"{self.__class__.__name__}({self.data})"
 
 
     @contextmanager
     @contextmanager

+ 22 - 0
pyproject.toml

@@ -8,3 +8,25 @@ line_length = 119
 combine_as_imports = true
 combine_as_imports = true
 combine_star = true
 combine_star = true
 known_local_folder = ["arguments", "test_utils", "tests", "utils"]
 known_local_folder = ["arguments", "test_utils", "tests", "utils"]
+
+
+
+[tool.mypy]
+plugins = [
+  "numpy.typing.mypy_plugin",
+  "pydantic.mypy"
+]
+
+follow_imports = "silent"
+warn_redundant_casts = true
+warn_unused_ignores = true
+disallow_any_generics = true
+check_untyped_defs = true
+no_implicit_reexport = true
+
+
+[tool.pydantic-mypy]
+init_forbid_extra = true
+init_typed = true
+warn_required_dynamic_aliases = true
+warn_untyped_fields = true