Pavel Samygin 3 years ago
Parent
Commit
222985cadd

+ 1 - 1
hivemind/moe/client/beam_search.py

@@ -18,7 +18,7 @@ from hivemind.moe.server.expert_uid import (
     is_valid_prefix,
 )
 from hivemind.p2p import PeerInfo
-from hivemind.utils import get_dht_time, get_logger, LazyFutureCaller, LazyValue
+from hivemind.utils import LazyFutureCaller, LazyValue, get_dht_time, get_logger

 logger = get_logger(__name__)


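Note: the reordering above matches isort's default behaviour, where names in a from-import are grouped by type (class-like names such as LazyFutureCaller and LazyValue before functions such as get_dht_time and get_logger). A minimal sketch of reproducing it through isort's Python API, assuming isort >= 5 is the tool behind this commit (the commit itself does not say):

# Hedged sketch: reproduce the beam_search.py import reordering with isort's
# public Python API (isort.code). Assumes isort >= 5 is installed; the commit
# does not state which tool, if any, produced the change.
import isort

before = "from hivemind.utils import get_dht_time, get_logger, LazyFutureCaller, LazyValue\n"
print(isort.code(before))
# Expected: from hivemind.utils import LazyFutureCaller, LazyValue, get_dht_time, get_logger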
+ 1 - 1
hivemind/moe/server/__init__.py

@@ -1,5 +1,5 @@
+from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.moe.server.dht_handler import declare_experts, get_experts
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.layers import register_expert_class
 from hivemind.moe.server.server import Server, background_server
-from hivemind.moe.server.connection_handler import ConnectionHandler

+ 2 - 2
hivemind/moe/server/connection_handler.py

@@ -1,6 +1,6 @@
 import asyncio
 import multiprocessing as mp
-from typing import AsyncIterator, Dict, Iterable, Union, Tuple, List
+from typing import AsyncIterator, Dict, Iterable, List, Tuple, Union

 import torch

@@ -11,7 +11,7 @@ from hivemind.moe.server.task_pool import TaskPool
 from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils import MSGPackSerializer, MPFuture, as_aiter, get_logger, nested_flatten
+from hivemind.utils import MPFuture, MSGPackSerializer, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor

+ 2 - 1
hivemind/moe/server/dht_handler.py

@@ -1,7 +1,8 @@
 import threading
 from functools import partial
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
 from multiaddr import Multiaddr
-from typing import Union, Dict, List, Optional, Sequence, Tuple

 from hivemind.dht import DHT, DHTExpiration, DHTNode, DHTValue
 from hivemind.moe.client.expert import RemoteExpert, _RemoteModuleCall

+ 2 - 2
hivemind/moe/server/server.py

@@ -1,11 +1,11 @@
 from __future__ import annotations

 import multiprocessing as mp
+import random
 import threading
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-import random
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -25,9 +25,9 @@ from hivemind.moe.server.layers import (
 )
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
+from hivemind.utils import Endpoint
 from hivemind.utils.logging import get_logger
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
-from hivemind.utils import Endpoint

 logger = get_logger(__name__)


+ 1 - 1
hivemind/utils/__init__.py

@@ -1,5 +1,6 @@
 from hivemind.utils.asyncio import *
 from hivemind.utils.grpc import *
+from hivemind.utils.lazy_value import LazyValue, LazyFutureCaller
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
@@ -9,4 +10,3 @@ from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *
-from hivemind.utils.lazy_value import LazyValue, LazyFutureCaller

+ 3 - 3
hivemind/utils/grpc.py

@@ -6,11 +6,10 @@ from __future__ import annotations

 import os
 import threading
-import torch
 from typing import (
-    Callable,
-    AsyncIterator,
     Any,
+    AsyncIterator,
+    Callable,
     Dict,
     Iterable,
     Iterator,
@@ -24,6 +23,7 @@ from typing import (
 )

 import grpc
+import torch

 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
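Taken together, these hunks are consistent with isort's default section ordering: standard-library imports (os, threading, typing, random) in one block, third-party packages (torch, grpc, multiaddr) in the next, each block alphabetized, which is why import torch moves down next to import grpc in the grpc.py hunk above. A minimal sketch, again assuming isort >= 5 (an assumption, not stated in the commit):

# Hedged sketch: isort separates the standard library from third-party packages
# with a blank line and sorts each section, mirroring the grpc.py change above.
# Assumes isort >= 5 is installed.
import isort

before = "import os\nimport threading\nimport torch\n\nimport grpc\n"
print(isort.code(before))
# Expected output:
#   import os
#   import threading
#
#   import grpc
#   import torch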