Ver código fonte

fix type check

justheuristic 3 anos atrás
pai
commit
14cbc17150

+ 1 - 1
src/client/__init__.py

@@ -1,4 +1,4 @@
 from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.remote_sequence_info import RemoteSequenceInfo
+from src.client.sequence_manager import RemoteSequenceManager
 from src.client.remote_sequential import RemoteSequential

+ 3 - 3
src/client/remote_sequential.py

@@ -12,7 +12,7 @@ from torch import nn
 
 import src
 from src.client.remote_block import RemoteTransformerBlock
-from src.client.remote_sequence_info import RemoteSequenceInfo
+from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
 
@@ -44,7 +44,7 @@ class RemoteSequential(nn.Module):
         block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
 
         logger.debug(f"Remote block uids: {block_uids}")
-        self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
+        self.remote_sequence_info = RemoteSequenceManager(dht, block_uids)
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -84,7 +84,7 @@ class RemoteSequential(nn.Module):
 class RemoteSequentialInferenceSession:
     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
 
-    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
+    def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
         self.remote_sequence_info = remote_sequence_info
         self.p2p = p2p
         self.closed = False

+ 5 - 8
src/client/remote_sequence_info.py → src/client/sequence_manager.py

@@ -1,29 +1,26 @@
 from __future__ import annotations
 
 import threading
-from typing import List, NamedTuple, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple
 
 from hivemind import DHT, PeerID
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState
+from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState, RemoteSpanInfo
 from src.dht_utils import get_remote_module_infos
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
-
-
-class RemoteSequenceInfo:
+class RemoteSequenceManager:
     """Keeps and updates the meta-information about which peers host which blocks"""
 
     dht: DHT
     block_uids: List[ModuleUID]
     block_infos: List[Optional[RemoteModuleInfo]]
-    spans_by_priority: List[Span]  # sorted from best to worst
-    spans_containing_block: Tuple[List[Span]]
+    spans_by_priority: List[RemoteSpanInfo]  # sorted from best to worst
+    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
     lock_changes: threading.Lock
 
     def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):

+ 10 - 0
src/data_structures.py

@@ -23,5 +23,15 @@ class ServerInfo:
 
 @dataclass
 class RemoteModuleInfo:
+    """A remote module that is served by one or more servers"""
     uid: ModuleUID
     servers: Dict[PeerID, ServerInfo]
+
+
+@dataclass
+class RemoteSpanInfo:
+    """A chain of remote blocks served by one specific remote peer"""
+    start: int
+    end: int
+    peer_id: PeerID
+