style fix

Pavel Samygin, 3 years ago
Parent commit: c7c877533b

+ 1 - 2
benchmarks/benchmark_throughput_p2p.py

@@ -113,7 +113,6 @@ def benchmark_throughput(
         for client in clients:
             client.start()
 
-
         timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
 
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
@@ -130,7 +129,6 @@ def benchmark_throughput(
             )
         timestamps["created_experts"] = time.perf_counter()
 
-
         server = hivemind.moe.Server(
             dht=server_dht,
             expert_backends=experts,
@@ -198,6 +196,7 @@ def benchmark_throughput(
 
     assert not benchmarking_failed.is_set()
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--preset", type=str, default="default", required=False)

+ 9 - 9
hivemind/moe/client/beam_search.py

@@ -261,10 +261,8 @@ class MoEBeamSearcher:
 
         p2p = _RemoteModuleCall.run_coroutine(self.dht.replicate_p2p())
         if return_future:
-            return LazyFutureCaller(
-                result,
-                lambda lst: [l.get(p2p=p2p) for l in lst]
-            )
+            return LazyFutureCaller(result, lambda lst: [l.get(p2p=p2p) for l in lst])
+
         return [r.get(p2p=p2p) for r in result]
 
     @classmethod
@@ -332,11 +330,13 @@ class MoEBeamSearcher:
                 unique_experts.add(uid_endpoint.uid)
 
         best_experts = [
-            LazyValue(init=partial(
-                RemoteExpert,
-                uid=uid_endpoint.uid,
-                server_peer_info=PeerInfo.from_endpoint(uid_endpoint.endpoint),
-            ))
+            LazyValue(
+                init=partial(
+                    RemoteExpert,
+                    uid=uid_endpoint.uid,
+                    server_peer_info=PeerInfo.from_endpoint(uid_endpoint.endpoint),
+                )
+            )
             for _, uid_endpoint in sorted(best_experts_heap, reverse=True)
         ]
         return best_experts
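
Two kinds of laziness land in this file. With return_future=True the expert list is wrapped in a LazyFutureCaller, so the p2p handle is bound to each entry only when the caller asks for the result; and each entry is itself a LazyValue that postpones constructing the RemoteExpert until .get() runs. A minimal sketch of the deferred-construction idea, with a plain stand-in for RemoteExpert:

    from functools import partial

    class FakeExpert:  # stand-in for RemoteExpert
        def __init__(self, uid, server_peer_info):
            print(f"dialing {uid} at {server_peer_info}")  # runs only on demand

    # nothing is constructed yet, just like LazyValue(init=partial(...)) above
    deferred = partial(FakeExpert, uid="expert.0", server_peer_info="peer-addr")
    expert = deferred()  # LazyValue.get() invokes its `init` like this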

+ 29 - 23
hivemind/moe/client/expert.py

@@ -12,7 +12,15 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils import MSGPackSerializer, amap_in_executor, as_aiter, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.utils import (
+    MSGPackSerializer,
+    amap_in_executor,
+    as_aiter,
+    nested_compare,
+    nested_flatten,
+    nested_pack,
+    switch_to_uvloop
+)
 from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
@@ -144,28 +152,27 @@ class _RemoteModuleCall(torch.autograd.Function):
         return tuple(deserialized_outputs)
 
     @classmethod
-    def forward_partial(
-        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub
-    ) -> List[torch.Tensor]:
+    def forward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
         split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
 
         outputs = cls.run_coroutine(
             stub.rpc_forward_partial(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
-                    as_aiter(*split)
+                    lambda t: runtime_pb2.ExpertRequest(
+                        uid=ctx.uid,
+                        tensors=[
+                            t,
+                        ],
+                    ),
+                    as_aiter(*split),
                 ),
             )
         )
 
-        return cls.run_coroutine(
-            gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
-        )
+        return cls.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))
 
     @classmethod
-    def forward_oneshot(
-        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub
-    ) -> List[torch.Tensor]:
+    def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
 
         outputs = cls.run_coroutine(
             stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
@@ -197,29 +204,28 @@ class _RemoteModuleCall(torch.autograd.Function):
 
     @classmethod
     @once_differentiable
-    def backward_partial(
-        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx
-    ) -> List[torch.Tensor]:
+    def backward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
         grad_inputs = cls.run_coroutine(
             ctx.stub.rpc_backward_partial(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
-                    as_aiter(*split)
+                    lambda t: runtime_pb2.ExpertRequest(
+                        uid=ctx.uid,
+                        tensors=[
+                            t,
+                        ],
+                    ),
+                    as_aiter(*split),
                 ),
             )
         )
 
-        return cls.run_coroutine(
-            gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
-        )
+        return cls.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))
 
     @classmethod
     @once_differentiable
-    def backward_oneshot(
-        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx
-    ) -> List[torch.Tensor]:
+    def backward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         grad_inputs = cls.run_coroutine(
             ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
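
Both *_partial paths share one pattern: each serialized tensor is split into parts no larger than half of DEFAULT_MAX_MSG_SIZE (leaving headroom for message framing), every part ships in its own single-tensor ExpertRequest, and gather_from_grpc reassembles the response stream. A self-contained asyncio sketch of that request generator, with a dict standing in for runtime_pb2.ExpertRequest and the helper semantics inferred from their usage above:

    import asyncio

    async def as_aiter_sketch(*items):
        # what as_aiter appears to do: expose plain items as an async iterator
        for item in items:
            yield item

    async def request_stream(uid, parts):
        # what the amap_in_executor(...) call builds: one request per part
        async for part in as_aiter_sketch(*parts):
            yield {"uid": uid, "tensors": [part]}  # stand-in for ExpertRequest

    async def main():
        async for req in request_stream("expert.0", [b"part-1", b"part-2"]):
            print(req["uid"], len(req["tensors"]))

    asyncio.run(main())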

+ 15 - 11
hivemind/moe/server/connection_handler.py

@@ -61,7 +61,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
 
     class _RequestUnpacker:
 
-        __slots__ = "uid",
+        __slots__ = ("uid",)
 
         def __init__(self):
             self.uid = None
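
This change is cosmetic: the trailing comma, not the parentheses, is what makes the tuple, so both spellings declare exactly one slot named "uid". The explicit form just reads less like an accident:

    class Unpacker:
        __slots__ = ("uid",)  # identical in behavior to `__slots__ = "uid",`

    u = Unpacker()
    u.uid = "expert.0"  # allowed: "uid" is a declared slot
    # u.extra = 1       # AttributeError: no __dict__ and no such slot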
@@ -82,22 +82,21 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         return unpacker.uid, inputs
 
     async def _process_inputs(
-        self, inputs: List[torch.Tensor], pool: TaskPool, schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]]
-    ):
+        self,
+        inputs: List[torch.Tensor],
+        pool: TaskPool,
+        schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]],
+    ) -> List[runtime_pb2.Tensor]:
         return [
             serialize_torch_tensor(t, p.compression, allow_inplace=True)
             for t, p in zip(await pool.submit_task(*inputs), nested_flatten(schema))
         ]
 
-    async def rpc_forward(
-        self, request: runtime_pb2.ExpertRequest, context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         expert = self.experts[request.uid]
         return runtime_pb2.ExpertResponse(
-            tensors=await self._process_inputs(
-                inputs, expert.forward_pool, expert.outputs_schema
-            )
+            tensors=await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
         )
 
     async def rpc_forward_partial(
@@ -106,12 +105,17 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         uid, inputs = await self._gather_inputs(requests, context)
         expert = self.experts[uid]
         output_split = [
-            p for t in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
+            p
+            for t in await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
             for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
         ]
 
         async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(tensors=[part, ])
+            yield runtime_pb2.ExpertResponse(
+                tensors=[
+                    part,
+                ],
+            )
 
     async def rpc_backward(
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
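
The server side of rpc_forward_partial mirrors the client: computed outputs are split under the same DEFAULT_MAX_MSG_SIZE // 2 cap and streamed back one part per ExpertResponse. A stand-alone sketch of that yield-per-part shape, chunking raw bytes instead of tensor protos:

    import asyncio

    CHUNK = 64 * 1024  # stand-in for DEFAULT_MAX_MSG_SIZE // 2

    async def stream_output_parts(payloads):
        # same shape as `yield runtime_pb2.ExpertResponse(tensors=[part])` above
        for payload in payloads:
            for i in range(0, len(payload), CHUNK):
                yield {"tensors": [payload[i : i + CHUNK]]}

    async def main():
        sizes = [len(m["tensors"][0]) async for m in stream_output_parts([b"x" * 150_000])]
        print(sizes)  # [65536, 65536, 18928]

    asyncio.run(main())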

+ 4 - 3
hivemind/moe/server/dht_handler.py

@@ -51,7 +51,8 @@ def declare_experts(
         assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
     addrs = tuple(str(a.decapsulate("/p2p/" + a.get("p2p"))) for a in dht.get_visible_maddrs())
     return dht.run_coroutine(
-        partial(_declare_experts, uids=list(uids), peer_id=peer_id, addrs=addrs, expiration=expiration), return_future=not wait
+        partial(_declare_experts, uids=list(uids), peer_id=peer_id, addrs=addrs, expiration=expiration),
+        return_future=not wait,
     )
 
 
@@ -104,8 +105,8 @@ async def _get_experts(
 
     experts: List[Optional[RemoteExpert]] = [None] * len(uids)
     for i, uid in enumerate(uids):
-        if (elem := found[uid]) is not None and \
-            isinstance(elem.value, tuple):
+        elem = found[uid]
+        if elem is not None and isinstance(elem.value, tuple):
             peer_id, addrs = elem.value
             peer_info = PeerInfo(peer_id=PeerID.from_base58(peer_id), addrs=tuple(Multiaddr(a) for a in addrs))
             experts[i] = LazyValue(init=partial(RemoteExpert, uid=uid, server_peer_info=peer_info))
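
The rewritten condition relies on the same short-circuit the walrus version did: isinstance(elem.value, tuple) is only evaluated once elem is known to exist, so moving the assignment onto its own line changes nothing but readability:

    elem = None
    # the right-hand side never touches elem.value when elem is None
    print(elem is not None and isinstance(elem.value, tuple))  # False, no AttributeError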

+ 1 - 0
hivemind/moe/server/server.py

@@ -348,6 +348,7 @@ def _server_runner(pipe, *args, **kwargs):
         server.join()
         logger.info("Server shut down.")
 
+
 def _generate_uids(
     num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
 ) -> List[str]:

+ 3 - 1
hivemind/p2p/p2p_daemon.py

@@ -518,7 +518,9 @@ class P2P:
 
         self._listen_task = asyncio.create_task(listen())
 
-    async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler, balanced: bool = False) -> None:
+    async def add_binary_stream_handler(
+        self, name: str, handler: p2pclient.StreamHandler, balanced: bool = False
+    ) -> None:
         if self._listen_task is None:
             self._start_listening()
         await self._client.stream_handler(name, handler, balanced)

+ 3 - 1
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -362,7 +362,9 @@ class ControlClient:
         reader, writer = await self.daemon_connector.open_connection()
 
         listen_path_maddr_bytes = self.listen_maddr.to_bytes()
-        stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto], balanced=balanced)
+        stream_handler_req = p2pd_pb.StreamHandlerRequest(
+            addr=listen_path_maddr_bytes, proto=[proto], balanced=balanced
+        )
         req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
         await write_pbmsg(writer, req)
 

+ 4 - 2
hivemind/p2p/servicer.py

@@ -104,7 +104,9 @@ class ServicerBase:
         caller.__name__ = handler.method_name
         return caller
 
-    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None, balanced: bool = False) -> None:
+    async def add_p2p_handlers(
+        self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None, balanced: bool = False
+    ) -> None:
         self._collect_rpc_handlers()
 
         servicer = self if wrapper is None else wrapper
@@ -116,7 +118,7 @@ class ServicerBase:
                     handler.request_type,
                     stream_input=handler.stream_input,
                     stream_output=handler.stream_output,
-                    balanced=balanced
+                    balanced=balanced,
                 )
                 for handler in self._rpc_handlers
             ]

+ 16 - 1
hivemind/utils/grpc.py

@@ -7,7 +7,21 @@ from __future__ import annotations
 import os
 import threading
 import torch
-from typing import Callable, AsyncIterator, Any, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+    Callable,
+    AsyncIterator,
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import grpc
 
@@ -213,6 +227,7 @@ def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.
 
 RpcMessage = TypeVar("RpcMessage")
 
+
 async def gather_from_grpc(
     stream: AsyncIterator[RpcMessage],
     key: Callable[[RpcMessage], Iterable[runtime_pb2.Tensor]],
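
The hunk above only reflows imports, but this module also holds the two helpers the rest of the commit leans on: split_for_streaming cuts a serialized tensor into size-capped parts, and combine_from_streaming reassembles them. A round-trip sketch under a default hivemind install (the single-argument serialize call is an assumption):

    import torch
    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
    from hivemind.utils.grpc import combine_from_streaming, split_for_streaming

    tensor = torch.randn(256, 256)
    serialized = serialize_torch_tensor(tensor)             # a runtime_pb2.Tensor
    parts = list(split_for_streaming(serialized, 2 ** 16))  # <= 64 KiB of buffer per part
    restored = deserialize_torch_tensor(combine_from_streaming(iter(parts)))
    assert torch.allclose(tensor, restored)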

+ 3 - 3
hivemind/utils/lazy_value.py

@@ -4,6 +4,7 @@ from hivemind.utils.mpfuture import MPFuture
 
 T = TypeVar("T")
 
+
 class _Empty(Generic[T]):
 
     _instance = None
@@ -14,9 +15,7 @@ class _Empty(Generic[T]):
         return cls._instance
 
 
-
 class LazyValue(Generic[T]):
-
     def __init__(self, value: T = _Empty(), init: Optional[Callable[..., T]] = None):
         assert value != _Empty() or init is not None, "One should provide either value or intializer"
         self.value = value
@@ -28,10 +27,11 @@ class LazyValue(Generic[T]):
 
         return self.value
 
+
 RT = TypeVar("RT")
 
-class LazyFutureCaller(Generic[T, RT]):
 
+class LazyFutureCaller(Generic[T, RT]):
     def __init__(self, future: MPFuture[T], callback: Optional[Callable[[T], RT]] = None):
         self._fut = future
         self._cb = callback
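
Together the two classes cover both flavors of laziness used in this commit: LazyValue defers a synchronous constructor until .get() is called, while LazyFutureCaller wraps an MPFuture and applies its callback to the result only on demand. A usage sketch of the LazyValue side (forwarding kwargs through .get() is inferred from the beam_search.py call sites above):

    from functools import partial
    from hivemind.utils.lazy_value import LazyValue

    def connect(uid, p2p=None):
        return f"connected {uid} via {p2p}"

    lazy = LazyValue(init=partial(connect, uid="expert.0"))
    # the initializer runs only when the value is finally requested:
    print(lazy.get(p2p="p2p-handle"))  # connected expert.0 via p2p-handle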

+ 3 - 1
hivemind/utils/networking.py

@@ -6,7 +6,9 @@ from typing import Optional, Sequence, Tuple
 from multiaddr import Multiaddr
 
 Hostname, Port = str, int  # flavour types
-Endpoint = Tuple[str, Tuple[str, ...]]  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
+Endpoint = Tuple[          # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
+    str, Tuple[str, ...]
+]
 LOCALHOST = "127.0.0.1"