Explorar o código

style changes

Pavel Samygin %!s(int64=3) %!d(string=hai) anos
pai
achega
28a554b40a

+ 1 - 1
hivemind/moe/client/expert.py

@@ -19,7 +19,7 @@ from hivemind.utils import (
     nested_compare,
     nested_flatten,
     nested_pack,
-    switch_to_uvloop
+    switch_to_uvloop,
 )
 from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
 

+ 8 - 5
hivemind/moe/server/connection_handler.py

@@ -123,9 +123,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         inputs_and_grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         expert = self.experts[request.uid]
         return runtime_pb2.ExpertResponse(
-            tensors=await self._process_inputs(
-                inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema
-            )
+            tensors=await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
         )
 
     async def rpc_backward_partial(
@@ -134,9 +132,14 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         uid, inputs_and_grads = await self._gather_inputs(requests, context)
         expert = self.experts[uid]
         output_split = [
-            p for t in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
+            p
+            for t in await self._process_inputs(inputs_and_grads, expert.backward_pool, expert.grad_inputs_schema)
             for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)
         ]
 
         async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(tensors=[part, ])
+            yield runtime_pb2.ExpertResponse(
+                tensors=[
+                    part,
+                ]
+            )

+ 5 - 3
hivemind/utils/networking.py

@@ -6,9 +6,11 @@ from typing import Optional, Sequence, Tuple
 from multiaddr import Multiaddr
 
 Hostname, Port = str, int  # flavour types
-Endpoint = Tuple[          # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
-    str, Tuple[str, ...]
-]
+Endpoint = (
+    Tuple[  # e.g. 1.2.3.4:1337 or [2a21:6с8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
+        str, Tuple[str, ...]
+    ]
+)
 LOCALHOST = "127.0.0.1"