|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
import contextlib
|
|
|
import logging
|
|
|
import random
|
|
|
-from typing import Optional, Union, List
|
|
|
+from typing import List, Optional, Union
|
|
|
|
|
|
import torch
|
|
|
from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
|
|
@@ -14,7 +14,7 @@ from torch import nn
|
|
|
import src
|
|
|
from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
|
|
|
from src.client.sequence_manager import RemoteSequenceManager
|
|
|
-from src.data_structures import UID_DELIMITER, RemoteSpanInfo, CHAIN_DELIMITER
|
|
|
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, RemoteSpanInfo
|
|
|
from src.dht_utils import _create_remote_modules_from_infos
|
|
|
from src.server.handler import TransformerConnectionHandler
|
|
|
|
|
@@ -108,7 +108,7 @@ class RemoteSequential(nn.Module):
|
|
|
class RemoteSequentialInferenceSession:
|
|
|
"""An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
|
|
|
|
|
|
- def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float]=None):
|
|
|
+ def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
|
|
|
self.sequence_manager = sequence_manager
|
|
|
self.p2p = p2p
|
|
|
self.closed = False
|
|
@@ -125,10 +125,12 @@ class RemoteSequentialInferenceSession:
|
|
|
|
|
|
for chosen_span in self.chosen_spans:
|
|
|
stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
|
|
|
- span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start: chosen_span.end])
|
|
|
- inference_session = RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(
|
|
|
- stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
|
|
|
- ))
|
|
|
+ span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
|
|
|
+ inference_session = RemoteExpertWorker.run_coroutine(
|
|
|
+ RemoteTransformerBlockInferenceSession._create(
|
|
|
+ stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
|
|
|
+ )
|
|
|
+ )
|
|
|
self.inference_sessions.append(inference_session)
|
|
|
self.stack.enter_context(inference_session)
|
|
|
|