Michael Diskin 3 年 前
コミット
fe885ecc57

+ 1 - 1
hivemind/averaging/load_balancing.py

@@ -10,7 +10,7 @@ logger = get_logger(__name__)
 LOAD_BALANCING_LP_DECIMALS = 9
 
 
-def load_balance_peers(vector_size, bandwidths: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
+def load_balance_peers(vector_size: int, bandwidths: Sequence[Optional[float]], min_size: int = 0) -> Tuple[int, ...]:
     """
     Find an optimal partitioning of weights for butterfly all-reduce given peer bandwidths.
     :param vector_size: total size of the averaged vector (in elements, not bytes)

+ 2 - 2
hivemind/averaging/partition.py

@@ -140,7 +140,7 @@ class TensorPartContainer:
         self._outputs_consumed = True
         peer_index = num_parts_processed = 0
         for tensor_index in range(len(self.local_tensors)):
-            tensor_parts: List[torch.Tensor] = []
+            tensor_parts = []
             while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
                 if num_parts_processed >= self.num_parts_by_peer[peer_index]:
                     num_parts_processed = 0
@@ -187,7 +187,7 @@ class TensorPartReducer:
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
-        self.accumulator: Optional[torch.Tensor] = None  # contains the sum of current tensor part from group peers
+        self.accumulator = None  # this will contain the sum of current tensor part from group peers
         self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
         self.finished = asyncio.Event()

+ 3 - 3
hivemind/dht/traverse.py

@@ -75,9 +75,9 @@ async def traverse_dht(
     num_workers: int,
     queries_per_call: int,
     get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[Tuple[DHTID], bool]]]],
-    visited_nodes: Dict[DHTID, Set[DHTID]],
     found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
     await_all_tasks: bool = True,
+    visited_nodes: Optional[Dict[DHTID, Set[DHTID]]] = (),
 ) -> Tuple[Dict[DHTID, List[DHTID]], Dict[DHTID, Set[DHTID]]]:
     """
     Search the DHT for nearest neighbors to :queries: (based on DHTID.xor_distance). Use get_neighbors to request peers.
@@ -109,7 +109,6 @@ async def traverse_dht(
         nearest neighbors and finishes the remaining tasks (callbacks and queries to known-but-unvisited nodes)
 
     :param visited_nodes: for each query, do not call get_neighbors on these nodes, nor return them among nearest.
-
     :note: the source code of this function can get tricky to read. Take a look at `simple_traverse_dht` function
         for reference. That function implements a special case of traverse_dht with a single query and one worker.
 
@@ -118,12 +117,13 @@ async def traverse_dht(
         visited nodes: { query -> a set of all nodes that received requests for a given query }
     """
     if len(queries) == 0:
-        return {}, visited_nodes
+        return {}, dict(visited_nodes or {})
 
     unfinished_queries = set(queries)  # all queries that haven't triggered finish_search yet
     candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: unvisited nodes, ordered nearest-to-farthest
     nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: top-k nearest nodes, farthest-to-nearest
     known_nodes: Dict[DHTID, Set[DHTID]] = {}  # all nodes ever added to the heap (for deduplication)
+    visited_nodes: Dict[DHTID, Set[DHTID]] = dict(visited_nodes or {})  # nodes that were chosen for get_neighbors call
     pending_tasks = set()  # all active tasks (get_neighbors and found_callback)
     active_workers = Counter({q: 0 for q in queries})  # count workers that search for this query
 

+ 2 - 2
hivemind/dht/validation.py

@@ -1,6 +1,6 @@
 import dataclasses
 from abc import ABC, abstractmethod
-from typing import Iterable, List
+from typing import Iterable
 
 
 @dataclasses.dataclass(init=True, repr=True, frozen=True)
@@ -91,7 +91,7 @@ class RecordValidatorBase(ABC):
 
 class CompositeValidator(RecordValidatorBase):
     def __init__(self, validators: Iterable[RecordValidatorBase] = ()):
-        self._validators: List[RecordValidatorBase] = []
+        self._validators = []
         self.extend(validators)
 
     def extend(self, validators: Iterable[RecordValidatorBase]) -> None:

+ 1 - 1
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -57,7 +57,7 @@ class DaemonConnector:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
 
-    async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+    async def open_persistent_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
         """
         Open connection to daemon and upgrade it to a persistent one
         """

+ 5 - 5
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -55,13 +55,13 @@ class PeerID:
     def __repr__(self) -> str:
         return f"<libp2p.peer.id.ID ({self.to_base58()})>"
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.to_base58()
 
-    def pretty(self):
+    def pretty(self) -> str:
         return self.to_base58()
 
-    def to_string(self):
+    def to_string(self) -> str:
         return self.to_base58()
 
     def __eq__(self, other: object) -> bool:
@@ -128,10 +128,10 @@ class PeerInfo:
         addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
         return PeerInfo(peer_id, addrs)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})"
 
 

+ 1 - 1
hivemind/utils/asyncio.py

@@ -111,7 +111,7 @@ async def amap_in_executor(
 ) -> AsyncIterator[T]:
     """iterate from an async iterable in a background thread, yield results to async iterable"""
     loop = asyncio.get_event_loop()
-    queue: asyncio.Queue[Optional[Awaitable[T]]] = asyncio.Queue(max_prefetch)
+    queue = asyncio.Queue(max_prefetch)
 
     async def _put_items():
         try:

+ 1 - 1
hivemind/utils/auth.py

@@ -63,7 +63,7 @@ class TokenAuthorizerBase(AuthorizerBase):
         self._local_access_token = None
         self._refresh_lock = asyncio.Lock()
 
-        self._recent_nonces: TimedStorage[bytes, None] = TimedStorage()
+        self._recent_nonces = TimedStorage()
 
     @abstractmethod
     async def get_token(self) -> AccessToken:

+ 8 - 7
hivemind/utils/crypto.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import base64
 import threading
 from abc import ABC, abstractmethod
+from typing import Any, Dict
 
 from cryptography import exceptions
 from cryptography.hazmat.primitives import hashes, serialization
@@ -39,7 +40,7 @@ _RSA_HASH_ALGORITHM = hashes.SHA256()
 
 
 class RSAPrivateKey(PrivateKey):
-    def __init__(self):
+    def __init__(self) -> None:
         self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
 
     _process_wide_key = None
@@ -60,7 +61,7 @@ class RSAPrivateKey(PrivateKey):
     def get_public_key(self) -> RSAPublicKey:
         return RSAPublicKey(self._private_key.public_key())
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         # Serializes the private key to make the class instances picklable
         state["_private_key"] = self._private_key.private_bytes(
@@ -70,13 +71,13 @@ class RSAPrivateKey(PrivateKey):
         )
         return state
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self.__dict__.update(state)
         self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
 
 
 class RSAPublicKey(PublicKey):
-    def __init__(self, public_key: rsa.RSAPublicKey):
+    def __init__(self, public_key: rsa.RSAPublicKey) -> None:
         self._public_key = public_key
 
     def verify(self, data: bytes, signature: bytes) -> bool:
@@ -97,7 +98,7 @@ class RSAPublicKey(PublicKey):
 
     @classmethod
     def from_bytes(cls, key: bytes) -> RSAPublicKey:
-        loaded_key = serialization.load_ssh_public_key(key)
-        if not isinstance(loaded_key, rsa.RSAPublicKey):
+        deserialized_key = serialization.load_ssh_public_key(key)
+        if not isinstance(deserialized_key, rsa.RSAPublicKey):
             raise ValueError(f"Expected an RSA public key, got {str(key)}")
-        return cls(loaded_key)
+        return cls(deserialized_key)

+ 3 - 3
hivemind/utils/grpc.py

@@ -54,7 +54,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
     _lock: threading.RLock = threading.RLock()
     _update_eviction_evt: threading.Event = threading.Event()
 
-    def __init__(self, _created_as_singleton=False):
+    def __init__(self, _created_as_singleton=False) -> None:
         assert _created_as_singleton, f"Please use {self.__class__.__name__}.get_singleton()"
         super().__init__(maxsize=self.MAXIMUM_CHANNELS)
         self._is_active = True
@@ -76,7 +76,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
     def get_stub(
         cls,
         target: Endpoint,
-        stub_type: Type[Stub],
+        stub_type: type,
         *,
         aio: bool,
         options: Tuple[Tuple[str, Any]] = (),
@@ -146,7 +146,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
                 target, credentials=channel_credentials, options=options, compression=compression
             )
 
-    def _evict_stale_channels_in_background(self):
+    def _evict_stale_channels_in_background(self) -> None:
         while self._is_active:
             now = get_dht_time()
             time_to_wait = max(0.0, self._nearest_expiration_time - now)

+ 1 - 1
hivemind/utils/limits.py

@@ -3,7 +3,7 @@ from hivemind.utils.logging import get_logger
 logger = get_logger(__name__)
 
 
-def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
+def increase_file_limit(new_soft: int = 2 ** 15, new_hard: int = 2 ** 15) -> None:
     """Increase the maximum number of open files. On Linux, this allows spawning more processes/threads."""
     try:
         import resource  # local import to avoid ImportError for Windows users

+ 1 - 1
hivemind/utils/logging.py

@@ -83,7 +83,7 @@ class CustomFormatter(logging.Formatter):
         return super().format(record)
 
 
-def _initialize_if_necessary():
+def _initialize_if_necessary() -> None:
     global _current_mode, _default_handler
 
     with _init_lock:

+ 15 - 16
hivemind/utils/mpfuture.py

@@ -3,7 +3,6 @@ from __future__ import annotations
 import asyncio
 import concurrent.futures._base as base
 import multiprocessing as mp
-import multiprocessing.connection
 import os
 import threading
 import uuid
@@ -131,13 +130,13 @@ class MPFuture(base.Future, Generic[ResultType]):
         if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
             self._set_event_threadsafe()
 
-    def _set_event_threadsafe(self):
+    def _set_event_threadsafe(self) -> None:
         try:
             running_loop = asyncio.get_running_loop()
         except RuntimeError:
             running_loop = None
 
-        async def _event_setter():
+        async def _event_setter() -> None:
             self._aio_event.set()
 
         if self._loop.is_closed():
@@ -150,7 +149,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             self._loop.run_until_complete(_event_setter())
 
     @classmethod
-    def _initialize_mpfuture_backend(cls):
+    def _initialize_mpfuture_backend(cls) -> None:
         pid = os.getpid()
         logger.debug(f"Initializing MPFuture backend for pid {pid}")
 
@@ -162,7 +161,7 @@ class MPFuture(base.Future, Generic[ResultType]):
         cls._pipe_waiter_thread.start()
 
     @staticmethod
-    def reset_backend():
+    def reset_backend() -> None:
         """Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
         MPFuture._active_pid = None
         MPFuture._initialization_lock = mp.Lock()
@@ -200,7 +199,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
-    def _send_update(self, update_type: UpdateType, payload: Any = None):
+    def _send_update(self, update_type: UpdateType, payload: Any = None) -> None:
         """This method sends result, exception or cancel to the MPFuture origin."""
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
@@ -208,7 +207,7 @@ class MPFuture(base.Future, Generic[ResultType]):
         except (ConnectionError, BrokenPipeError, EOFError, OSError) as e:
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e})", exc_info=True)
 
-    def set_result(self, result: ResultType):
+    def set_result(self, result: ResultType) -> None:
         if os.getpid() == self._origin_pid:
             super().set_result(result)
             MPFuture._active_futures.pop(self._uid, None)
@@ -218,7 +217,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             self._state_cache[self._state], self._result = base.FINISHED, result
             self._send_update(UpdateType.RESULT, result)
 
-    def set_exception(self, exception: Optional[BaseException]):
+    def set_exception(self, exception: Optional[BaseException]) -> None:
         if os.getpid() == self._origin_pid:
             super().set_exception(exception)
             MPFuture._active_futures.pop(self._uid, None)
@@ -239,7 +238,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             self._send_update(UpdateType.CANCEL)
             return True
 
-    def set_running_or_notify_cancel(self):
+    def set_running_or_notify_cancel(self) -> bool:
         if self._state == base.PENDING:
             self._state = base.RUNNING
             return True
@@ -274,18 +273,18 @@ class MPFuture(base.Future, Generic[ResultType]):
     def done(self) -> bool:
         return self._state in TERMINAL_STATES
 
-    def running(self):
+    def running(self) -> bool:
         return self._state == base.RUNNING
 
-    def cancelled(self):
+    def cancelled(self) -> bool:
         return self._state == base.CANCELLED
 
-    def add_done_callback(self, callback: Callable[[MPFuture], None]):
+    def add_done_callback(self, callback: Callable[[MPFuture], None]) -> None:
         if os.getpid() != self._origin_pid:
             raise RuntimeError("Only the process that created MPFuture can set callbacks")
         return super().add_done_callback(callback)
 
-    def __await__(self):
+    def __await__(self) -> Any:
         if not self._aio_event:
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
         yield from self._aio_event.wait().__await__()
@@ -294,13 +293,13 @@ class MPFuture(base.Future, Generic[ResultType]):
         except base.CancelledError:
             raise asyncio.CancelledError()
 
-    def __del__(self):
+    def __del__(self) -> None:
         if getattr(self, "_origin_pid", None) == os.getpid():
             MPFuture._active_futures.pop(self._uid, None)
         if getattr(self, "_aio_event", None):
             self._aio_event.set()
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         return dict(
             _sender_pipe=self._sender_pipe,
             _shared_state_code=self._shared_state_code,
@@ -311,7 +310,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             _exception=self._exception,
         )
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         self._sender_pipe = state["_sender_pipe"]
         self._shared_state_code = state["_shared_state_code"]
         self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]

+ 6 - 5
hivemind/utils/nested.py

@@ -1,7 +1,8 @@
 """ utility functions that help you process nested dicts, tuples, lists and namedtuples """
+from typing import Any
 
 
-def nested_compare(t, u):
+def nested_compare(t: Any, u: Any) -> bool:
     """
     Return whether the nested structures of t and u match.
     """
@@ -29,7 +30,7 @@ def nested_compare(t, u):
         return True
 
 
-def nested_flatten(t):
+def nested_flatten(t: Any) -> Any:
     """
     Turn nested list/tuple/dict into a flat iterator.
     """
@@ -43,7 +44,7 @@ def nested_flatten(t):
         yield t
 
 
-def nested_pack(flat, structure):
+def nested_pack(flat: Any, structure: Any) -> Any:
     """
     Restore nested structure from flattened state
     :param flat: result of nested_flatten
@@ -53,7 +54,7 @@ def nested_pack(flat, structure):
     return _nested_pack(iter(flat), structure)
 
 
-def _nested_pack(flat_iter, structure):
+def _nested_pack(flat_iter: Any, structure: Any) -> Any:
     if is_namedtuple(structure):
         return type(structure)(*[_nested_pack(flat_iter, x) for x in structure])
     elif isinstance(structure, (list, tuple)):
@@ -64,7 +65,7 @@ def _nested_pack(flat_iter, structure):
         return next(flat_iter)
 
 
-def is_namedtuple(x):
+def is_namedtuple(x) -> bool:
     """Checks if x is a namedtuple instance. Taken from https://stackoverflow.com/a/2166841 ."""
     t = type(x)
     b = t.__bases__

+ 5 - 2
hivemind/utils/networking.py

@@ -5,7 +5,8 @@ from typing import Optional, Sequence
 
 from multiaddr import Multiaddr
 
-Hostname, Port = str, int  # flavour types
+Hostname = str
+Port = int  # flavour types
 Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6c8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = "127.0.0.1"
 
@@ -30,7 +31,9 @@ def strip_port(endpoint: Endpoint) -> Hostname:
     return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
 
 
-def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+def get_free_port(
+    params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+) -> Port:
     """
     Finds a tcp port that can be occupied with a socket with *params and use *opt options.
 

+ 2 - 2
hivemind/utils/performance_ema.py

@@ -37,7 +37,7 @@ class PerformanceEMA:
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
         return self.samples_per_second
 
-    def reset_timer(self):
+    def reset_timer(self) -> None:
         """Reset the time since the last update so that the next task performance is counted from current time"""
         self.timestamp = time.perf_counter()
 
@@ -51,7 +51,7 @@ class PerformanceEMA:
             self.paused = was_paused
             self.reset_timer()
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
 
     @contextmanager

+ 4 - 4
hivemind/utils/serializer.py

@@ -1,6 +1,6 @@
 """ A unified interface for several common serialization methods """
 from abc import ABC, abstractmethod
-from typing import Any, Dict
+from typing import Any, Callable, Dict
 
 import msgpack
 
@@ -27,10 +27,10 @@ class MSGPackSerializer(SerializerBase):
     _TUPLE_EXT_TYPE_CODE = 0x40
 
     @classmethod
-    def ext_serializable(cls, type_code: int):
+    def ext_serializable(cls, type_code: int) -> Callable[[type], type]:
         assert isinstance(type_code, int), "Please specify a (unique) int type code"
 
-        def wrap(wrapped_type: type):
+        def wrap(wrapped_type: type) -> type:
             assert callable(getattr(wrapped_type, "packb", None)) and callable(
                 getattr(wrapped_type, "unpackb", None)
             ), f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
@@ -42,7 +42,7 @@ class MSGPackSerializer(SerializerBase):
         return wrap
 
     @classmethod
-    def _encode_ext_types(cls, obj):
+    def _encode_ext_types(cls, obj: object):
         type_code = cls._ext_types.get(type(obj))
         if type_code is not None:
             return msgpack.ExtType(type_code, obj.packb())

+ 12 - 12
hivemind/utils/timed_storage.py

@@ -5,7 +5,7 @@ import heapq
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Dict, Generic, Iterator, List, Optional, Tuple, TypeVar
+from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
 
 KeyType = TypeVar("KeyType")
 ValueType = TypeVar("ValueType")
@@ -20,10 +20,10 @@ class ValueWithExpiration(Generic[ValueType]):
     value: ValueType
     expiration_time: DHTExpiration
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Union[ValueType, DHTExpiration]]:
         return iter((self.value, self.expiration_time))
 
-    def __getitem__(self, item):
+    def __getitem__(self, item: Union[int, str]) -> Union[ValueType, DHTExpiration]:
         if item == 0:
             return self.value
         elif item == 1:
@@ -31,7 +31,7 @@ class ValueWithExpiration(Generic[ValueType]):
         else:
             return getattr(self, item)
 
-    def __eq__(self, item):
+    def __eq__(self, item: Any) -> bool:
         if isinstance(item, ValueWithExpiration):
             return self.value == item.value and self.expiration_time == item.expiration_time
         elif isinstance(item, tuple):
@@ -51,13 +51,13 @@ class TimedStorage(Generic[KeyType, ValueType]):
 
     frozen = False  # can be set to True. If true, do not remove outdated elements
 
-    def __init__(self, maxsize: Optional[int] = None):
+    def __init__(self, maxsize: Optional[int] = None) -> None:
         self.maxsize = maxsize or float("inf")
         self.data: Dict[KeyType, ValueWithExpiration[ValueType]] = dict()
         self.expiration_heap: List[HeapEntry[KeyType]] = []
         self.key_to_heap: Dict[KeyType, HeapEntry[KeyType]] = dict()
 
-    def _remove_outdated(self):
+    def _remove_outdated(self) -> None:
         while (
             not self.frozen
             and self.expiration_heap
@@ -108,28 +108,28 @@ class TimedStorage(Generic[KeyType, ValueType]):
             return top_key, self.data[top_key]
         return None, None
 
-    def clear(self):
+    def clear(self) -> None:
         self.data.clear()
         self.key_to_heap.clear()
         self.expiration_heap.clear()
 
-    def __contains__(self, key: KeyType):
+    def __contains__(self, key: KeyType) -> bool:
         self._remove_outdated()
         return key in self.data
 
-    def __len__(self):
+    def __len__(self) -> int:
         self._remove_outdated()
         return len(self.data)
 
-    def __delitem__(self, key: KeyType):
+    def __delitem__(self, key: KeyType) -> None:
         if key in self.key_to_heap:
             del self.data[key], self.key_to_heap[key]
         # note: the key may remain in self.expiration_heap, but it is unused and ._remove_outdated() will eventually drop it
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.data)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.data})"
 
     @contextmanager

+ 22 - 0
pyproject.toml

@@ -8,3 +8,25 @@ line_length = 119
 combine_as_imports = true
 combine_star = true
 known_local_folder = ["arguments", "test_utils", "tests", "utils"]
+
+
+
+[tool.mypy]
+plugins = [
+  "numpy.typing.mypy_plugin",
+  "pydantic.mypy"
+]
+
+follow_imports = "silent"
+warn_redundant_casts = true
+warn_unused_ignores = true
+disallow_any_generics = true
+check_untyped_defs = true
+no_implicit_reexport = true
+
+
+[tool.pydantic-mypy]
+init_forbid_extra = true
+init_typed = true
+warn_required_dynamic_aliases = true
+warn_untyped_fields = true