justheuristic před 3 roky
rodič
revize
ec20fa6ec5

+ 1 - 1
src/client/__init__.py

@@ -1,5 +1,5 @@
+from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
 from src.client.remote_block import RemoteTransformerBlock
-from src.client.inference_session import RemoteTransformerBlockInferenceSession, RemoteSequentialInferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential
 from src.client.sequence_manager import RemoteSequenceManager

+ 11 - 5
src/client/inference_session.py

@@ -2,20 +2,26 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
-from typing import AsyncIterator, Optional, List
+from typing import AsyncIterator, List, Optional
 
 import torch
-from hivemind import serialize_torch_tensor, nested_flatten, deserialize_torch_tensor, anext, P2P, \
-    use_hivemind_log_handler, get_logger
+from hivemind import (
+    P2P,
+    anext,
+    deserialize_torch_tensor,
+    get_logger,
+    nested_flatten,
+    serialize_torch_tensor,
+    use_hivemind_log_handler,
+)
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
 
 from src.client.sequence_manager import RemoteSequenceManager
-from src.data_structures import ModuleUID, RPCInfo, RemoteSpanInfo, CHAIN_DELIMITER
+from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
 
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 

+ 0 - 2
src/client/remote_block.py

@@ -42,5 +42,3 @@ class RemoteTransformerBlock(RemoteExpert):
     def begin_inference_session(self):
         logger.warning("beging_inference_session was renamed to just inference_session")
         return self.inference_session()
-
-