Bläddra i källkod

style fixes, delete unused lazy value

Pavel Samygin 3 år sedan
förälder
incheckning
6b086e4793

+ 0 - 1
benchmarks/benchmark_throughput.py

@@ -3,7 +3,6 @@ import multiprocessing as mp
 import random
 import sys
 import time
-from grpc import server
 
 import torch
 

+ 17 - 12
hivemind/moe/client/expert.py

@@ -1,19 +1,19 @@
-from dataclasses import dataclass
+import os
 from concurrent.futures import Future
+from dataclasses import dataclass
 from lib2to3.pgen2.token import OP
-from multiaddr import Multiaddr
-import os
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple
 
 import torch
 import torch.nn as nn
+from multiaddr import Multiaddr
 from torch.autograd.function import once_differentiable
 
 import hivemind
-from hivemind.dht import DHT
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
@@ -36,6 +36,7 @@ DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autogra
 def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
     return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
 
+
 @dataclass(frozen=True)
 class RemoteExpertInfo:
     uid: str
@@ -45,8 +46,7 @@ class RemoteExpertInfo:
     @property
     def as_peer_info(self) -> Tuple[str, PeerInfo]:
         return self.uid, PeerInfo(
-            peer_id=PeerID.from_base58(self.peer_id),
-            addrs=tuple(Multiaddr(a) for a in self.addrs)
+            peer_id=PeerID.from_base58(self.peer_id), addrs=tuple(Multiaddr(a) for a in self.addrs)
         )
 
 
@@ -103,7 +103,6 @@ class RemoteExpertWorker:
     _event_thread: Optional[Thread] = None
     _pid: int = 0
 
-
     @classmethod
     def _run(cls):
         loop = switch_to_uvloop()
@@ -137,9 +136,12 @@ class RemoteExpertWorker:
         return result
 
     @classmethod
-    def spawn_experts_future(cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT) -> MPFuture[List[Optional[RemoteExpert]]]:
+    def spawn_experts_future(
+        cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
+    ) -> MPFuture[List[Optional[RemoteExpert]]]:
         async def _unpack():
             return cls.spawn_experts(await infos, dht)
+
         return cls.run_coroutine(_unpack, True)
 
     @classmethod
@@ -155,7 +157,6 @@ class RemoteExpertWorker:
         return experts
 
 
-
 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
 
@@ -209,7 +210,9 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        return RemoteExpertWorker.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(
+            gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+        )
 
     @classmethod
     def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
@@ -261,7 +264,9 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        return RemoteExpertWorker.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(
+            gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+        )
 
     @classmethod
     @once_differentiable

+ 4 - 6
hivemind/moe/server/dht_handler.py

@@ -2,8 +2,6 @@ import threading
 from functools import partial
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
-from multiaddr import Multiaddr
-
 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo, RemoteExpertWorker
 from hivemind.moe.server.expert_uid import (
@@ -16,8 +14,8 @@ from hivemind.moe.server.expert_uid import (
     is_valid_uid,
     split_uid,
 )
-from hivemind.p2p import PeerID, PeerInfo
-from hivemind.utils import get_dht_time, LazyFutureCaller, LazyValue
+from hivemind.p2p import PeerID
+from hivemind.utils import get_dht_time, MPFuture
 
 
 class DHTHandlerThread(threading.Thread):
@@ -37,7 +35,7 @@ class DHTHandlerThread(threading.Thread):
 
 def declare_experts(
     dht: DHT, uids: Sequence[ExpertUID], peer_id: PeerID, expiration: DHTExpiration = 300, wait: bool = True
-) -> Dict[ExpertUID, bool]:
+) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]:
     """
     Make experts visible to all DHT peers; update timestamps if declared previously.
 
@@ -77,7 +75,7 @@ async def _declare_experts(
 
 def get_experts(
     dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
-) -> Union[List[Optional[RemoteExpert]], LazyFutureCaller[Optional[LazyValue[RemoteExpert]], Optional[RemoteExpert]]]:
+) -> Union[List[Optional[RemoteExpert]], MPFuture[List[Optional[RemoteExpert]]]]:
     """
     :param uids: find experts with these ids from across the DHT
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)

+ 0 - 1
hivemind/utils/__init__.py

@@ -1,6 +1,5 @@
 from hivemind.utils.asyncio import *
 from hivemind.utils.grpc import *
-from hivemind.utils.lazy_value import LazyValue, LazyFutureCaller
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *

+ 0 - 43
hivemind/utils/lazy_value.py

@@ -1,43 +0,0 @@
-from typing import Any, Generic, TypeVar, Callable, Optional, Union
-
-from hivemind.utils.mpfuture import MPFuture
-
-T = TypeVar("T")
-
-
-class _Empty(Generic[T]):
-
-    _instance = None
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super(_Empty, cls).__new__(cls, *args, **kwargs)
-        return cls._instance
-
-
-class LazyValue(Generic[T]):
-    def __init__(self, value: T = _Empty(), init: Optional[Callable[..., T]] = None):
-        assert value != _Empty() or init is not None, "One should provide either value or intializer"
-        self.value = value
-        self.init = init
-
-    def get(self, *args, **kwargs) -> T:
-        if self.value == _Empty():
-            self.value = self.init(*args, **kwargs)
-
-        return self.value
-
-
-RT = TypeVar("RT")
-
-
-class LazyFutureCaller(Generic[T, RT]):
-    def __init__(self, future: MPFuture[T], callback: Optional[Callable[[T], RT]] = None):
-        self._fut = future
-        self._cb = callback
-
-    def result(self) -> Union[T, RT]:
-        result = self._fut.result()
-        if self._cb is not None:
-            return self._cb(result)
-        return result