Michael Diskin committed 4 years ago · commit e727ec923c

+ 8 - 6
hivemind/averaging/partition.py

@@ -2,7 +2,7 @@
 Auxiliary data structures for AllReduceRunner
 """
 import asyncio
-from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator
+from typing import Sequence, AsyncIterable, Tuple, Optional, TypeVar, Union, AsyncIterator, Type, Deque, List
 from collections import deque
 
 import torch
@@ -32,7 +32,7 @@ class TensorPartContainer:
         self,
         tensors: Sequence[torch.Tensor],
         peer_fractions: Sequence[float],
-        compression_type: Union[type(CompressionType), Sequence[type(CompressionType)]] = CompressionType.NONE,
+        compression_type: Union[Type[CompressionType], Sequence[Type[CompressionType]]] = CompressionType.NONE,
         part_size_bytes: int = 2 ** 20,
         prefetch: int = 1,
     ):
@@ -42,8 +42,10 @@ class TensorPartContainer:
         self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
         self.compression_type, self.part_size_bytes, self.prefetch = compression_type, part_size_bytes, prefetch
         self.total_size = sum(tensor.numel() for tensor in tensors)
-        self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
-        self._output_parts_by_peer = [deque() for _ in range(self.group_size)]
+        self._input_parts_by_peer: List[Deque[Tuple[torch.Tensor, Type[CompressionType]]]] = [
+            deque() for _ in range(self.group_size)
+        ]
+        self._output_parts_by_peer: List[Deque[torch.Tensor]] = [deque() for _ in range(self.group_size)]
         self._inputs_consumed_by_peer = [False for _ in range(self.group_size)]
         self._output_part_available = [asyncio.Event() for _ in range(self.group_size)]
         self._outputs_registered_by_peer = [0 for _ in range(self.group_size)]
@@ -124,7 +126,7 @@ class TensorPartContainer:
         self._outputs_consumed = True
         peer_index = num_parts_processed = 0
         for tensor_index in range(len(self.local_tensors)):
-            tensor_parts = []
+            tensor_parts: List[torch.Tensor] = []
             while len(tensor_parts) < self.num_parts_by_tensor[tensor_index]:
                 if num_parts_processed >= self.num_parts_by_peer[peer_index]:
                     num_parts_processed = 0
@@ -173,7 +175,7 @@ class TensorPartReducer:
         assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
-        self.accumulator = None  # this will contain the sum of current tensor part from group peers
+        self.accumulator: Optional[torch.Tensor] = None  # contains the sum of current tensor part from group peers
         self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
         self.finished = asyncio.Event()
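
Note: the pattern throughout this file is giving the type checker the element types of containers that would otherwise be inferred from a bare deque() or [], plus marking the lazily-filled accumulator as Optional; Type[CompressionType] replaces the type(CompressionType) call, which is a runtime expression that a static checker cannot read as an annotation. A minimal sketch of the same annotations on a toy class (not hivemind's TensorPartContainer; a plain int stands in for the compression enum, and torch is assumed to be installed):

    from collections import deque
    from typing import Deque, List, Optional, Tuple

    import torch


    class PartBuffer:
        """Toy stand-in for the annotated containers in TensorPartContainer/TensorPartReducer."""

        def __init__(self, group_size: int):
            # without the annotation, the element type of each deque is unknown to the checker
            self.parts_by_peer: List[Deque[Tuple[torch.Tensor, int]]] = [
                deque() for _ in range(group_size)
            ]
            self.accumulator: Optional[torch.Tensor] = None  # filled lazily, hence Optional

        def add_part(self, peer_index: int, part: torch.Tensor, compression: int) -> None:
            self.parts_by_peer[peer_index].append((part, compression))
            self.accumulator = part if self.accumulator is None else self.accumulator + part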

+ 1 - 1
hivemind/dht/node.py

@@ -290,7 +290,7 @@ class DHTNode:
         if k_nearest > beam_size:
             logger.warning("Warning: beam_size is too small, beam search is not guaranteed to find enough nodes")
         if node_to_peer_id is None:
-            node_to_peer_id: Dict[DHTID, PeerID] = dict()
+            node_to_peer_id = dict()
             for query in queries:
                 neighbors = self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id)
                 node_to_peer_id.update(self._filter_blacklisted(dict(neighbors)))
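
Note: node_to_peer_id is already typed on the parameter list (presumably Optional[Dict[DHTID, PeerID]]), so annotating it again inside the body is redundant at best, and a strict checker commonly rejects it as a re-declaration of the name; a plain assignment is enough. A hedged toy version of the pattern with illustrative names:

    from typing import Dict, Optional


    def resolve_peers(node_to_peer_id: Optional[Dict[int, str]] = None) -> Dict[int, str]:
        if node_to_peer_id is None:
            node_to_peer_id = dict()  # no second annotation: the parameter already fixes the type
        return node_to_peer_id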

+ 3 - 3
hivemind/dht/traverse.py

@@ -75,9 +75,9 @@ async def traverse_dht(
     num_workers: int,
     queries_per_call: int,
     get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[Tuple[DHTID], bool]]]],
+    visited_nodes: Dict[DHTID, Set[DHTID]],
     found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
     await_all_tasks: bool = True,
-    visited_nodes: Optional[Dict[DHTID, Set[DHTID]]] = (),
 ) -> Tuple[Dict[DHTID, List[DHTID]], Dict[DHTID, Set[DHTID]]]:
     """
     Search the DHT for nearest neighbors to :queries: (based on DHTID.xor_distance). Use get_neighbors to request peers.
@@ -109,6 +109,7 @@ async def traverse_dht(
         nearest neighbors and finishes the remaining tasks (callbacks and queries to known-but-unvisited nodes)
 
     :param visited_nodes: for each query, do not call get_neighbors on these nodes, nor return them among nearest.
+
     :note: the source code of this function can get tricky to read. Take a look at `simple_traverse_dht` function
         for reference. That function implements a special case of traverse_dht with a single query and one worker.
 
@@ -117,13 +118,12 @@ async def traverse_dht(
         visited nodes: { query -> a set of all nodes that received requests for a given query }
     """
     if len(queries) == 0:
-        return {}, dict(visited_nodes or {})
+        return {}, visited_nodes
 
     unfinished_queries = set(queries)  # all queries that haven't triggered finish_search yet
     candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: unvisited nodes, ordered nearest-to-farthest
     nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}  # heap: top-k nearest nodes, farthest-to-nearest
     known_nodes: Dict[DHTID, Set[DHTID]] = {}  # all nodes ever added to the heap (for deduplication)
-    visited_nodes: Dict[DHTID, Set[DHTID]] = dict(visited_nodes or {})  # nodes that were chosen for get_neighbors call
     pending_tasks = set()  # all active tasks (get_neighbors and found_callback)
     active_workers = Counter({q: 0 for q in queries})  # count workers that search for this query
 

+ 2 - 2
hivemind/dht/validation.py

@@ -1,6 +1,6 @@
 import dataclasses
 from abc import ABC, abstractmethod
-from typing import Iterable
+from typing import Iterable, List
 
 
 @dataclasses.dataclass(init=True, repr=True, frozen=True)
@@ -91,7 +91,7 @@ class RecordValidatorBase(ABC):
 
 class CompositeValidator(RecordValidatorBase):
     def __init__(self, validators: Iterable[RecordValidatorBase] = ()):
-        self._validators = []
+        self._validators: List[RecordValidatorBase] = []
         self.extend(validators)
 
     def extend(self, validators: Iterable[RecordValidatorBase]) -> None:
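
Note: a bare [] gives the checker nothing to infer the element type from (mypy in strict mode asks for an annotation on such variables), so the list is declared as List[RecordValidatorBase]. The same pattern on a self-contained toy class:

    from typing import Iterable, List


    class Validator:
        ...


    class CompositeValidator(Validator):
        def __init__(self, validators: Iterable[Validator] = ()):
            self._validators: List[Validator] = []  # element type stated up front
            self._validators.extend(validators)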

+ 1 - 1
hivemind/moe/server/layers/custom_experts.py

@@ -14,7 +14,7 @@ def add_custom_models_from_file(path: str):
     spec.loader.exec_module(foo)
 
 
-def register_expert_class(name: str, sample_input: Callable[[int, int], torch.tensor]):
+def register_expert_class(name: str, sample_input: Callable[[int, int], torch.Tensor]):
     """
     Adds a custom user expert to hivemind server.
     :param name: the name of the expert. It shouldn't coincide with existing modules\
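
Note: torch.tensor is the factory function while torch.Tensor is the class, and only the class belongs inside Callable[...]. A small self-contained sketch of a callable that satisfies the corrected annotation (the function body is illustrative, not hivemind's):

    from typing import Callable

    import torch


    def random_sample(batch_size: int, hidden_dim: int) -> torch.Tensor:
        return torch.randn(batch_size, hidden_dim)


    sample_input: Callable[[int, int], torch.Tensor] = random_sample
    assert isinstance(sample_input(4, 16), torch.Tensor)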

+ 3 - 3
hivemind/moe/server/task_pool.py

@@ -10,7 +10,7 @@ from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 from concurrent.futures import Future
 from queue import Empty
-from typing import List, Tuple, Dict, Any, Generator
+from typing import List, Tuple, Dict, Any, Generator, Callable
 
 import torch
 
@@ -24,7 +24,7 @@ Task = namedtuple("Task", ("future", "args"))
 class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
     """A pool that accepts tasks and forms batches for parallel processing, interacts with Runtime"""
 
-    def __init__(self, process_func: callable, daemon=True, **kwargs):
+    def __init__(self, process_func: Callable, daemon=True, **kwargs):
         super().__init__(daemon=daemon, **kwargs)
         self.process_func = process_func
         self._priority = mp.Value(ctypes.c_double, 1.0)  # higher priority = the more urgent to process this pool
@@ -73,7 +73,7 @@ class TaskPool(TaskPoolBase):
 
     def __init__(
         self,
-        process_func: callable,
+        process_func: Callable,
         max_batch_size: int,
         name: str,
         min_batch_size=1,
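
Note: lowercase callable is the built-in predicate (callable(len) is True), not a type, so annotations use typing.Callable instead. A minimal illustration:

    from typing import Callable


    def apply_twice(func: Callable[[int], int], value: int) -> int:
        # with typing.Callable the checker can verify both the argument and return types of func
        return func(func(value))


    assert callable(apply_twice)  # the builtin remains useful as a runtime check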

+ 1 - 1
hivemind/optim/performance_ema.py

@@ -11,7 +11,7 @@ class PerformanceEMA:
 
     def __init__(self, alpha: float = 0.1, eps: float = 1e-20):
         self.alpha, self.eps, self.num_updates = alpha, eps, 0
-        self.ema_seconds_per_sample, self.samples_per_second = 0, eps
+        self.ema_seconds_per_sample, self.samples_per_second = 0.0, eps
         self.timestamp = get_dht_time()
         self.paused = False
 

+ 1 - 1
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -43,7 +43,7 @@ class DaemonConnector:
         self.control_maddr = control_maddr
         self.proto_code = parse_conn_protocol(self.control_maddr)
 
-    async def open_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+    async def open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
         if self.proto_code == protocols.P_UNIX:
             control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX)
             return await asyncio.open_unix_connection(control_path)
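
Note: a parenthesized pair of classes like (asyncio.StreamReader, asyncio.StreamWriter) is just a runtime tuple and means nothing as an annotation; Tuple[...] is the static equivalent and matches what asyncio.open_unix_connection actually returns. A standalone sketch (assuming Tuple is imported from typing in the module):

    import asyncio
    from typing import Tuple


    async def open_unix_control(path: str) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        # open_unix_connection resolves to exactly this (reader, writer) pair (Unix only)
        return await asyncio.open_unix_connection(path)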

+ 2 - 5
hivemind/utils/asyncio.py

@@ -70,14 +70,11 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
 
 
 async def amap_in_executor(
-    func: Callable[..., T],
-    *iterables: AsyncIterable,
-    max_prefetch: Optional[int] = None,
-    executor: Optional[ThreadPoolExecutor] = None
+    func: Callable[..., T], *iterables: AsyncIterable, max_prefetch: int, executor: Optional[ThreadPoolExecutor] = None
 ) -> AsyncIterator[T]:
     """iterate from an async iterable in a background thread, yield results to async iterable"""
     loop = asyncio.get_event_loop()
-    queue = asyncio.Queue(max_prefetch)
+    queue: asyncio.Queue[Optional[Awaitable[T]]] = asyncio.Queue(max_prefetch)
 
     async def _put_items():
         async for args in azip(*iterables):
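
Note: with max_prefetch now required, every caller has to state how many results may sit in the internal queue; asyncio.Queue(maxsize) then blocks the producer once that many items are pending, which is what bounds prefetching. A simplified stand-in for the pattern, without the executor offload that amap_in_executor performs (names and the sentinel convention are illustrative):

    import asyncio
    from typing import AsyncIterator, Callable, Optional, TypeVar

    T = TypeVar("T")


    async def prefetch_map(func: Callable[[int], T], n: int, max_prefetch: int) -> AsyncIterator[T]:
        queue: "asyncio.Queue[Optional[T]]" = asyncio.Queue(max_prefetch)  # bounded: put() waits when full

        async def producer() -> None:
            for i in range(n):
                await queue.put(func(i))
            await queue.put(None)  # sentinel marking the end of the stream

        task = asyncio.create_task(producer())
        try:
            while (item := await queue.get()) is not None:
                yield item
        finally:
            task.cancel()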

+ 1 - 1
hivemind/utils/auth.py

@@ -63,7 +63,7 @@ class TokenAuthorizerBase(AuthorizerBase):
         self._local_access_token = None
         self._refresh_lock = asyncio.Lock()
 
-        self._recent_nonces = TimedStorage()
+        self._recent_nonces: TimedStorage[bytes, None] = TimedStorage()
 
     @abstractmethod
     async def get_token(self) -> AccessToken:
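
Note: TimedStorage is generic over its key and value types (see the TimedStorage[ChannelInfo, ...] base class in the grpc.py diff below), so the annotation records that the keys are nonce bytes and that no payload is attached to them. The same annotation style on a hypothetical generic container, not hivemind's TimedStorage:

    from typing import Dict, Generic, Tuple, TypeVar

    KeyType = TypeVar("KeyType")
    ValueType = TypeVar("ValueType")


    class ExpiringStore(Generic[KeyType, ValueType]):
        def __init__(self) -> None:
            self._data: Dict[KeyType, Tuple[ValueType, float]] = {}

        def store(self, key: KeyType, value: ValueType, expiration_time: float) -> None:
            self._data[key] = (value, expiration_time)


    recent_nonces: ExpiringStore[bytes, None] = ExpiringStore()  # same shape as the annotation in the diff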

+ 2 - 2
hivemind/utils/compression.py

@@ -39,7 +39,7 @@ def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: in
     return lookup
 
 
-def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
+def _quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10 ** 5) -> np.ndarray:
     """Estimate uniform quantiles of data using quantile-of-quantiles. Runs in parallel."""
     if not array.data.c_contiguous and array.data.f_contiguous:
         array = array.T
@@ -71,7 +71,7 @@ def _get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
 def _uint8_uniform_buckets_encode(tensor: torch.Tensor, range_in_sigmas: float):
     offset = UINT8_RANGE // 2
     shift = tensor.mean()
-    scale = range_in_sigmas * tensor.std() / UINT8_RANGE
+    scale = range_in_sigmas * tensor.std().item() / UINT8_RANGE
 
     quant_weight = torch.quantize_per_tensor(tensor - shift, scale, offset, torch.quint8).int_repr()
     lookup = average_buckets(tensor, quant_weight, UINT8_RANGE)
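
Note: two distinct fixes here: np.array is a constructor function while np.ndarray is the array class that belongs in annotations, and tensor.std() returns a zero-dimensional tensor, so .item() converts it to the plain Python float used as the quantization scale. A quick illustration with arbitrary values standing in for range_in_sigmas and UINT8_RANGE:

    import numpy as np
    import torch


    def describe(array: np.ndarray) -> str:  # np.ndarray is the type; np.array is just a factory
        return f"shape={array.shape}, dtype={array.dtype}"


    tensor = torch.randn(1024)
    scale = 6.0 * tensor.std().item() / 256  # .item(): 0-d tensor -> built-in float
    quantized = torch.quantize_per_tensor(tensor - tensor.mean(), scale, 128, torch.quint8).int_repr()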

+ 5 - 5
hivemind/utils/crypto.py

@@ -30,7 +30,7 @@ class PublicKey(ABC):
 
     @classmethod
     @abstractmethod
-    def from_bytes(cls, key: bytes) -> bytes:
+    def from_bytes(cls, key: bytes) -> PublicKey:
         ...
 
 
@@ -97,7 +97,7 @@ class RSAPublicKey(PublicKey):
 
     @classmethod
     def from_bytes(cls, key: bytes) -> RSAPublicKey:
-        key = serialization.load_ssh_public_key(key)
-        if not isinstance(key, rsa.RSAPublicKey):
-            raise ValueError(f"Expected an RSA public key, got {key}")
-        return cls(key)
+        loaded_key = serialization.load_ssh_public_key(key)
+        if not isinstance(loaded_key, rsa.RSAPublicKey):
+            raise ValueError(f"Expected an RSA public key, got {str(key)}")
+        return cls(loaded_key)
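
Note: two fixes in this file: the abstract from_bytes was annotated as returning bytes even though implementations return key objects, and the RSA override rebound the key parameter to a value of a different type, which shadows the original bytes and confuses type narrowing; binding the deserialized object to a fresh name keeps both types straight. Returning the class's own name from inside its body relies on postponed annotation evaluation, which the module presumably already enables given the existing -> RSAPublicKey annotation. A minimal self-contained sketch of that shape (DummyPublicKey and its fake deserialization are made up):

    from __future__ import annotations

    from abc import ABC, abstractmethod


    class PublicKey(ABC):
        @classmethod
        @abstractmethod
        def from_bytes(cls, key: bytes) -> PublicKey:  # returns a key object, not bytes
            ...


    class DummyPublicKey(PublicKey):
        def __init__(self, raw: bytes):
            self._raw = raw

        @classmethod
        def from_bytes(cls, key: bytes) -> DummyPublicKey:
            loaded_key = bytes(key)  # stand-in for real deserialization; the key parameter keeps its type
            return cls(loaded_key)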

+ 4 - 4
hivemind/utils/grpc.py

@@ -96,7 +96,7 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
         cache = cls.get_singleton()
         with cls._lock:
             key = ChannelInfo(target, aio, tuple(options), channel_credentials, compression)
-            entry: ValueWithExpiration = super(cls, cache).get(key)
+            entry: Optional[ValueWithExpiration[Tuple[Any, Dict[Any, Any]]]] = super(cls, cache).get(key)
 
             if entry is not None:
                 channel, stubs = entry.value
@@ -165,13 +165,13 @@ class ChannelCache(TimedStorage[ChannelInfo, Tuple[Union[grpc.Channel, grpc.aio.
             self._is_active = False
             self._update_eviction_evt.set()
 
-    def store(self, *args, **kwargs) -> ValueError:
+    def store(self, *args, **kwargs):
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
-    def get(self, *args, **kwargs) -> ValueError:
+    def get(self, *args, **kwargs):
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
 
-    def top(self) -> ValueError:
+    def top(self):
         raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")
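
Note: these methods only raise, so -> ValueError (which claims a ValueError is returned) was misleading; dropping the annotation is enough, and typing.NoReturn is the explicit alternative when a method never returns normally. The same hunk also annotates the cache lookup as Optional[ValueWithExpiration[...]] because a cache miss yields None. A small sketch of the NoReturn variant (the class name is illustrative):

    from typing import Any, NoReturn


    class StubOnlyCache:
        def store(self, *args: Any, **kwargs: Any) -> NoReturn:
            # never returns normally: it always raises, so NoReturn (or no annotation) is accurate
            raise ValueError(f"Please use {self.__class__.__name__}.get_stub to get or create stubs")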