
design-doc

justheuristic 3 years ago
commit 0613095168

+ 37 - 0
src/client/TODO_design.py

@@ -0,0 +1,37 @@
+# idea:
+# class RemoteSequence:
+#     """A chain of specific remote peers; created by RemoteSequenceManager.make_sequence()"""
+#     spans: Sequence[Span]  # spans that describe which specific modules are assigned to which remote peer
+#     # note: RemoteSequenceManager.make_sequence should use load balancing!
+#
+# class RemoteSequential(nn.Module):
+#     def forward(self, inputs: torch.Tensor):
+#         return _RemoteSequentialCall.apply(inputs, self.sequence_manager(), *self.todo_stuff())
+#     def inference_session(self, **stuff):
+#         self.remote_sequence_info.update_()
+#         return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
+
+# class _RemoteSequentialCall(torch.autograd.Function):
+#     """
+#     A pytorch autograd-compatible function that calls a sequence of transformer blocks on remote peers
+
+#     :note: this function splits input data into batches for efficient parallel processing
+#     :note: forward and backward passes may sometimes be served by different modules!
+#     """
+#
+#     def forward(ctx, inputs: torch.Tensor):
+#         input_batches: List[torch.Tensor] = split_into_batches(inputs, MAX_TOKENS_PER_BATCH)
+#         forward_passes: List[concurrent.futures.Future] = []
+#         for input_batch in input_batches:
+#             future = RemoteExpertWorker.run_coroutine(
+#                 async_forward_pass(RemoteSequenceManager, input_batch), return_future=True
+#             )  # ^-- async_forward_pass runs RemoteSequenceManager.form_sequence() and runs the forward pass along that chain;
+#             #    if spans[i] breaks, use RemoteSequenceManager[spans[i].start : spans[i].end].form_sequence() to repair it
+#             forward_passes.append(future)
+#         concurrent.futures.wait(forward_passes)
+#         output_batches = [future.result() for future in forward_passes]
+#         save_intermediate_states(ctx, forward_passes)  # save both your sequence and intermediate states.
+#         # ^-- sequence from forward pass is reused for backward! - and repaired the same way
+#         # [IMPORTANT] maybe first create an op for one batch, then a wrapper that splits inputs into batches
+#         return torch.cat(output_batches, dim=0)
+#
+#     def backward(ctx, grad_outputs):
+#         return TODO(ctx, )

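An illustrative, self-contained sketch of the batching pattern the design notes above describe: split the input into token-limited batches, dispatch each batch as a future, then concatenate the results. MAX_TOKENS_PER_BATCH, split_into_batches and the local forward_pass stand-in are assumptions for the sake of the example; the real call would go through the sequence manager to remote peers.

import concurrent.futures
from typing import List

import torch

MAX_TOKENS_PER_BATCH = 1024  # assumed limit, mirroring the design note above


def split_into_batches(inputs: torch.Tensor, max_tokens: int) -> List[torch.Tensor]:
    """Split [batch, seq_len, hidden] inputs so each chunk carries at most max_tokens tokens."""
    rows_per_chunk = max(1, max_tokens // inputs.shape[1])
    return list(torch.split(inputs, rows_per_chunk, dim=0))


def forward_pass(batch: torch.Tensor) -> torch.Tensor:
    """Local identity stand-in for the remote chain call made via the sequence manager."""
    return batch


def run_batched_forward(inputs: torch.Tensor) -> torch.Tensor:
    input_batches = split_into_batches(inputs, MAX_TOKENS_PER_BATCH)
    with concurrent.futures.ThreadPoolExecutor() as executor:
        forward_passes = [executor.submit(forward_pass, batch) for batch in input_batches]
        output_batches = [future.result() for future in forward_passes]
    return torch.cat(output_batches, dim=0)


if __name__ == "__main__":
    hidden_states = torch.randn(8, 128, 64)  # [batch, seq_len, hidden]
    assert run_batched_forward(hidden_states).shape == hidden_states.shape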
+ 1 - 1
src/client/__init__.py

@@ -1,4 +1,4 @@
 from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.remote_sequence_info import RemoteSequenceInfo
+from src.client.sequence_manager import RemoteSequenceManager
 from src.client.remote_sequential import RemoteSequential

+ 3 - 3
src/client/remote_sequential.py

@@ -12,7 +12,7 @@ from torch import nn
 
 import src
 from src.client.remote_block import RemoteTransformerBlock
-from src.client.remote_sequence_info import RemoteSequenceInfo
+from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
 
@@ -44,7 +44,7 @@ class RemoteSequential(nn.Module):
         block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
 
         logger.debug(f"Remote block uids: {block_uids}")
-        self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
+        self.remote_sequence_info = RemoteSequenceManager(dht, block_uids)
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -84,7 +84,7 @@ class RemoteSequential(nn.Module):
 class RemoteSequentialInferenceSession:
     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
 
-    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
+    def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
         self.remote_sequence_info = remote_sequence_info
         self.p2p = p2p
         self.closed = False

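A small usage sketch of the renamed manager, following the hunk above: block UIDs are built as f"{prefix}{UID_DELIMITER}{index}" and forward() expects activations of shape [batch, seq_len, n_embed]. The prefix and config numbers below are made-up placeholders, and the DHT call assumes a reachable swarm of servers hosting those blocks.

import hivemind
import torch

from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import UID_DELIMITER

prefix = "bigscience/test-bloom"  # placeholder DHT prefix for the hosted blocks
n_layer, n_embed = 24, 1024       # placeholder config values

dht = hivemind.DHT(start=True)    # in practice, pass initial_peers of the target swarm
block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(n_layer))
sequence_manager = RemoteSequenceManager(dht, block_uids)

hidden_states = torch.randn(1, 16, n_embed)  # forward() asserts ndim == 3 and shape[-1] == n_embed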
+ 2 - 2
src/client/remote_sequence_info.py → src/client/sequence_manager.py

@@ -18,8 +18,7 @@ logger = get_logger(__file__)
 Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
 
 
-@dataclasses.dataclass(frozen=False, init=False)  # TODO[borzunov@] this is not a dataclass
-class RemoteSequenceInfo:
+class RemoteSequenceManager:
     """Keeps and updates the meta-information about which peers host which blocks"""
 
     dht: DHT
@@ -92,3 +91,4 @@ class RemoteSequenceInfo:
 
     def __len__(self):
         return len(self.block_uids)
+
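Since the design note asks for load balancing in make_sequence, here is a toy illustration of how the Span tuple above could be used to build a sequence covering all blocks. The round-robin policy and string peer ids are hypothetical stand-ins; the real code uses hivemind's PeerID, and an actual policy would weigh peers by load or throughput.

from typing import List, NamedTuple, Optional


class Span(NamedTuple):
    start: int
    end: Optional[int]
    peer_id: str  # stand-in for hivemind's PeerID


def make_sequence(n_blocks: int, peers: List[str], blocks_per_peer: int = 8) -> List[Span]:
    """Cover blocks [0, n_blocks) with consecutive spans, assigning peers round-robin."""
    spans, start = [], 0
    while start < n_blocks:
        end = min(start + blocks_per_peer, n_blocks)
        spans.append(Span(start, end, peers[len(spans) % len(peers)]))
        start = end
    return spans


assert make_sequence(24, ["peerA", "peerB"]) == [
    Span(0, 8, "peerA"), Span(8, 16, "peerB"), Span(16, 24, "peerA"),
]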