
rename functions & add descriptions

dbaranchuk 3 years ago
parent commit a2d020afcd
1 changed file with 24 additions and 15 deletions

+ 24 - 15
src/client/sequential_autograd.py

@@ -16,11 +16,13 @@ from src.server.handler import TransformerConnectionHandler
 MAX_TOKENS_IN_BATCH = 1024
 
 
-async def run_forward(
+async def run_expert_forward(
     uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
-    TODO: add description
+    Serializes input tensors and calls "expert_forward".
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
+    but without the RemoteExpertWorker.run_coroutine() call, which leads to a deadlock here.
     """
 
     # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
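The renamed run_expert_forward follows a serialize → RPC → deserialize pattern, and awaiting the stub directly (instead of hopping through RemoteExpertWorker.run_coroutine()) is what avoids the deadlock mentioned in the docstring. Below is a minimal, hypothetical sketch of that pattern only, not the real hivemind/Petals API: fake_rpc_forward and the torch.save-based helpers stand in for the actual compressed tensor serialization and the "expert_forward" endpoint.

    import io
    from typing import Sequence, Tuple

    import torch


    def _serialize(tensor: torch.Tensor) -> bytes:
        # Stand-in for hivemind's compressed tensor serialization (assumption, not the real helper).
        buf = io.BytesIO()
        torch.save(tensor.detach().cpu(), buf)
        return buf.getvalue()


    def _deserialize(blob: bytes) -> torch.Tensor:
        return torch.load(io.BytesIO(blob))


    async def fake_rpc_forward(uid: str, payloads: Sequence[bytes]) -> Sequence[bytes]:
        # Stand-in for the remote "expert_forward" endpoint; here it simply echoes the payloads back.
        return list(payloads)


    async def run_expert_forward_sketch(uid: str, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        payloads = [_serialize(t) for t in inputs]          # serialize inputs on CPU
        replies = await fake_rpc_forward(uid, payloads)     # single awaited call, no extra worker thread
        return tuple(_deserialize(b) for b in replies)      # deserialize the returned tensors
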
@@ -47,7 +49,7 @@ async def run_forward(
     return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])
 
 
-async def run_backward(
+async def run_expert_backward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
@@ -55,7 +57,9 @@ async def run_backward(
     grad_outputs: List[torch.Tensor],
 ) -> Sequence[torch.Tensor]:
     """
-    TODO: add description
+    Serializes grad outputs and calls "expert_backward".
+    Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
+    but without the RemoteExpertWorker.run_coroutine() call, which leads to a deadlock here.
     """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
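run_expert_backward mirrors the same serialize → RPC → deserialize pattern, except the payload combines the original inputs with the gradients w.r.t. the outputs, all detached and moved to CPU first (as the grad_outputs_cpu line above does). A tiny hypothetical helper illustrating only that preparation step:

    import torch


    def prepare_backward_payload(inputs, grad_outputs):
        # Hypothetical helper (not part of the real code): detach and move everything to CPU
        # before serialization, mirroring the grad_outputs_cpu line in the diff above.
        inputs_cpu = tuple(t.detach().cpu() for t in inputs)
        grads_cpu = tuple(t.detach().cpu() for t in grad_outputs)
        return inputs_cpu + grads_cpu
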
@@ -70,11 +74,13 @@ async def run_backward(
     return deserialized_grad_inputs
 
 
-async def async_forward(
+async def sequential_forward(
     inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
-    TODO: add description
+    Constructs a routing path from <start_index> to <end_index>.
+    Performs a chained forward pass over each subsequence of blocks on the path.
+    If a subsequence fails, reconstructs the remaining path and tries to finish the forward pass.
     """
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
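The fault tolerance described in the sequential_forward docstring boils down to: pop the next span, try to run it, and on failure re-plan the remaining route and continue. A simplified single-process sketch of that loop, where make_route and run_span are hypothetical stand-ins for the sequence manager's routing and run_expert_forward, and spans are plain (start, end) tuples rather than RemoteSpanInfo:

    from typing import Callable, List, Tuple

    import torch


    def make_route(start: int, end: int) -> List[Tuple[int, int]]:
        # Hypothetical planner: split [start, end) into contiguous spans of blocks.
        mid = (start + end) // 2
        return [(start, mid), (mid, end)] if end - start > 1 else [(start, end)]


    async def sequential_forward_sketch(
        inputs: torch.Tensor,
        run_span: Callable[[Tuple[int, int], torch.Tensor], torch.Tensor],
        start: int = 0,
        end: int = 4,
    ):
        sequences = make_route(start, end)
        intermediate_inputs, done_sequences = [], []
        while sequences:
            span = sequences.pop(0)
            try:
                outputs = run_span(span, inputs)  # may fail if the peer serving this span goes offline
            except Exception:
                # Rebuild the remaining route starting from the failed span and keep going.
                sequences = make_route(span[0], end)
                continue
            intermediate_inputs.append(inputs)    # keep activations for the backward pass
            done_sequences.append(span)
            inputs = outputs
        return inputs, intermediate_inputs, done_sequences
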
@@ -92,7 +98,7 @@ async def async_forward(
                 span = sequences.pop(0)
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                (outputs,) = await run_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+                (outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -112,14 +118,15 @@ async def async_forward(
     return outputs, intermediate_inputs, done_sequences
 
 
-async def async_backward(
+async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: Sequence[torch.Tensor],
     forward_sequences: Sequence[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
     """
-    TODO: add description
+    Performs a chained backward pass over each forward subsequence.
+    If a subsequence fails, reconstructs that sub-path and recovers the backward pass.
     """
 
     assert len(intermediate_inputs) == len(forward_sequences)
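sequential_backward walks the completed forward spans in reverse; when a span's backward fails, the corresponding sub-path is rebuilt by re-running the forward over it, and the backward pass is retried on the recovered spans. A hedged sketch of that recovery loop, where run_span_backward and rerun_forward are hypothetical async callables standing in for run_expert_backward and sequential_forward:

    async def sequential_backward_sketch(grad_outputs, intermediate_inputs, forward_sequences,
                                         run_span_backward, rerun_forward):
        while forward_sequences:
            span = forward_sequences.pop()        # walk the forward spans in reverse order
            inputs = intermediate_inputs.pop()
            try:
                grad_outputs = await run_span_backward(span, inputs, grad_outputs)
            except Exception:
                # Rebuild the failed sub-path (possibly through a different peer), restore its
                # activations, and let the loop retry the backward over the recovered spans.
                _, backup_inputs, backup_spans = await rerun_forward(inputs, span[0], span[1])
                intermediate_inputs.extend(backup_inputs)
                forward_sequences.extend(backup_spans)
        return grad_outputs
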
@@ -134,11 +141,11 @@ async def async_backward(
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
 
-                grad_outputs = await run_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
+                grad_outputs = await run_expert_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
-                _, backup_intermediate_inputs, backup_forward_sequences = await async_forward(
+                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
                     inputs, sequence_manager, start_index=span.start, end_index=span.end
                 )
 
@@ -152,13 +159,15 @@ async def async_backward(
 
 
 async def _gather_forward(input_batches, sequence_manager):
-    return await asyncio.gather(*[async_forward(input_batch, sequence_manager) for input_batch in input_batches])
+    """ Wrapper for asyncio.gather to perform parallel sequential forwards """
+    return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
 
 
 async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
+    """ Wrapper for asyncio.gather to perform parallel sequential backwards """
     return await asyncio.gather(
         *[
-            async_backward((grad_output,), input_batch, spans, sequence_manager)
+            sequential_backward((grad_output,), input_batch, spans, sequence_manager)
             for grad_output, input_batch, spans in zip(
                 grad_output_batches, intermediate_input_batches, forward_sequences
             )
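The two _gather_* wrappers exist only to fan the per-batch coroutines out through asyncio.gather so that all batches travel through the swarm concurrently. A toy, runnable usage example of the same fan-out, with toy_forward standing in for sequential_forward:

    import asyncio

    import torch


    async def toy_forward(batch: torch.Tensor) -> torch.Tensor:
        await asyncio.sleep(0)                     # yield to the event loop, like a real network call
        return batch * 2


    async def main():
        batches = torch.randn(6, 4).split(2)       # three mini-batches of two samples each
        outputs = await asyncio.gather(*[toy_forward(b) for b in batches])
        print([tuple(o.shape) for o in outputs])   # [(2, 4), (2, 4), (2, 4)]


    asyncio.run(main())
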
@@ -168,8 +177,8 @@ async def _gather_backward(grad_output_batches, intermediate_input_batches, forw
 
 class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
-    A pytorch autograd-compatible function that calls a sequence of transformer blocks on remote peers
-    :note: this function splits input data into batches for efficient parallel processing
+    PyTorch autograd function that provides forward and backward calls for the entire sequence of remote transformer blocks.
+    This function splits input data into batches of at most <MAX_TOKENS_IN_BATCH> tokens and processes them in parallel for efficiency.
     """
 
     @staticmethod
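
The class docstring above says the input is split into batches of at most <MAX_TOKENS_IN_BATCH> tokens before the parallel forward/backward. One plausible way to perform such a split for inputs of shape (batch, seq_len, hidden), shown purely as an illustration and not as the function's actual implementation:

    import torch

    MAX_TOKENS_IN_BATCH = 1024


    def split_into_batches(inputs: torch.Tensor):
        # Cap each chunk at roughly MAX_TOKENS_IN_BATCH tokens (samples_per_chunk * seq_len),
        # keeping at least one sample per chunk.
        assert inputs.ndim == 3                    # (batch, seq_len, hidden)
        samples_per_chunk = max(1, MAX_TOKENS_IN_BATCH // inputs.shape[1])
        return inputs.split(samples_per_chunk, dim=0)


    chunks = split_into_batches(torch.randn(8, 512, 16))
    print([tuple(c.shape) for c in chunks])        # four chunks of shape (2, 512, 16)
    merged = torch.cat(chunks, dim=0)              # results are concatenated back after processing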