|
@@ -2,20 +2,26 @@ from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
|
import contextlib
|
|
|
-from typing import AsyncIterator, Optional, List
|
|
|
+from typing import AsyncIterator, List, Optional
|
|
|
|
|
|
import torch
|
|
|
-from hivemind import serialize_torch_tensor, nested_flatten, deserialize_torch_tensor, anext, P2P, \
|
|
|
- use_hivemind_log_handler, get_logger
|
|
|
+from hivemind import (
|
|
|
+ P2P,
|
|
|
+ anext,
|
|
|
+ deserialize_torch_tensor,
|
|
|
+ get_logger,
|
|
|
+ nested_flatten,
|
|
|
+ serialize_torch_tensor,
|
|
|
+ use_hivemind_log_handler,
|
|
|
+)
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.p2p import StubBase
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
|
|
|
from src.client.sequence_manager import RemoteSequenceManager
|
|
|
-from src.data_structures import ModuleUID, RPCInfo, RemoteSpanInfo, CHAIN_DELIMITER
|
|
|
+from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
|
|
|
from src.server.handler import TransformerConnectionHandler
|
|
|
|
|
|
-
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
|
|