dbaranchuk 3 years ago
Parent
Commit
c925f4d45b
1 changed file with 37 additions and 49 deletions
  1. +37 −49 src/client/sequential_autograd.py

src/client/sequential_autograd.py (+37 −49)

@@ -1,29 +1,23 @@
+import asyncio
 import logging
-from typing import Optional, List, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple

 import torch
-import asyncio
-
 from hivemind import serialize_torch_tensor
-from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
-from hivemind.moe.client.expert import expert_forward, expert_backward
+from hivemind.moe.client.expert import expert_backward, expert_forward
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack

 from src.client.sequence_manager import RemoteSequenceManager
+from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
-from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo, ModuleUID, RemoteSpanInfo, RPCInfo

-
-MAX_TOKENS_IN_BATCH=1024
+MAX_TOKENS_IN_BATCH = 1024


 async def run_forward(
-    uid: ModuleUID, 
-    stub: StubBase,
-    rpc_info: RPCInfo,
-    *inputs: torch.Tensor,
-    **kwargs
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
     TODO: add description
@@ -33,13 +27,13 @@ async def run_forward(
     # detach to avoid pickling the computation graph
     assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
     kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
-    
+
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
     forward_inputs = (inputs, kwargs)

     if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
         raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
- 
+
     forward_inputs = nested_flatten(forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)

@@ -54,11 +48,11 @@ async def run_forward(


 async def run_backward(
-    uid: ModuleUID, 
+    uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    intemediate_inputs: List[torch.Tensor], 
-    grad_outputs: List[torch.Tensor], 
+    intemediate_inputs: List[torch.Tensor],
+    grad_outputs: List[torch.Tensor],
 ) -> Sequence[torch.Tensor]:
     """
     TODO: add description
@@ -77,10 +71,7 @@ async def run_backward(


 async def async_forward(
-    inputs: torch.Tensor, 
-    sequence_manager: RemoteSequenceManager,
-    start_index: int = 0, 
-    end_index: Optional[int] = None
+    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     TODO: add description
@@ -99,9 +90,9 @@ async def async_forward(
         while True:
             try:
                 span = sequences.pop(0)
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start: span.end])
+                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                (outputs, ) = await run_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+                (outputs,) = await run_forward(span_uids, stub, sequence_manager.rpc_info, inputs)

                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -123,9 +114,9 @@ async def async_forward(

 async def async_backward(
     grad_outputs: Sequence[torch.Tensor],
-    intermediate_inputs: Sequence[torch.Tensor],  
-    forward_sequences: Sequence[RemoteSpanInfo], 
-    sequence_manager: RemoteSequenceManager
+    intermediate_inputs: Sequence[torch.Tensor],
+    forward_sequences: Sequence[RemoteSpanInfo],
+    sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
     """
     TODO: add description
@@ -133,24 +124,22 @@ async def async_backward(

     assert len(intermediate_inputs) == len(forward_sequences)
     # TODO think about grads w.r.t. deep prompts
-    
+
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         while True:
             try:
                 inputs = intermediate_inputs.pop(-1)
                 span = forward_sequences.pop(-1)

-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start: span.end])
+                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                
-                grad_outputs = await run_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
-                )
+
+                grad_outputs = await run_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 _, backup_intermediate_inputs, backup_forward_sequences = await async_forward(
-                    inputs, sequence_manager, start_index=span.start, end_index=span.end 
+                    inputs, sequence_manager, start_index=span.start, end_index=span.end
                 )

                 assert len(intermediate_inputs) == len(forward_sequences)
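The backward loop above retries a failed span by re-running the forward pass over just that span to recover its intermediate inputs. A minimal, self-contained sketch of this retry-with-recompute pattern, with hypothetical stand-ins (flaky_backward, recompute_forward) for the real RPC calls:

    import asyncio, logging

    async def flaky_backward(span, inputs, grads):  # hypothetical stand-in for run_backward
        if span.get("fail_once"):  # simulate a peer failing on the first attempt
            span["fail_once"] = False
            raise RuntimeError("peer went offline")
        return [g * 2 for g in grads]

    async def recompute_forward(span, inputs):  # hypothetical stand-in for re-running async_forward
        return inputs  # pretend we recovered the same activations from a backup peer

    async def backward_with_retry(spans, saved_inputs, grads):
        while spans:
            inputs, span = saved_inputs.pop(), spans.pop()
            while True:
                try:
                    grads = await flaky_backward(span, inputs, grads)
                    break
                except Exception as e:
                    logging.warning(f"Caught {e}, recomputing span")
                    inputs = await recompute_forward(span, inputs)
        return grads

    print(asyncio.run(backward_with_retry([{"fail_once": True}], [[1.0]], [1.0])))  # [2.0]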
@@ -163,17 +152,18 @@ async def async_backward(


 async def _gather_forward(input_batches, sequence_manager):
-    return await asyncio.gather(*[
-        async_forward(input_batch, sequence_manager)
-        for input_batch in input_batches
-    ])
+    return await asyncio.gather(*[async_forward(input_batch, sequence_manager) for input_batch in input_batches])


 async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
-    return await asyncio.gather(*[
-        async_backward((grad_output, ), input_batch, spans, sequence_manager)
-        for grad_output, input_batch, spans in zip(grad_output_batches, intermediate_input_batches, forward_sequences)
-    ])
+    return await asyncio.gather(
+        *[
+            async_backward((grad_output,), input_batch, spans, sequence_manager)
+            for grad_output, input_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, forward_sequences
+            )
+        ]
+    )


 class _RemoteSequentialAutogradFunction(torch.autograd.Function):
@@ -181,16 +171,14 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     A pytorch autograd-compatible function that calls a sequence of transformer blocks on remote peers
     :note: this function splits input data into batches for efficient parallel processing
     """
- 
+
     @staticmethod
     def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.split(batch_size)

-        sequence_manager.rpc_info # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(
-            _gather_forward(input_batches, sequence_manager)
-        )
+        sequence_manager.rpc_info  # lazy init
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
         assert len(outputs) == len(input_batches)

         output_batches = [output[0] for output in outputs]
@@ -201,12 +189,12 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         ctx.intemediate_input_batches = intemediate_input_batches
         ctx.sequences_for_batches = sequences_for_batches
         return torch.cat(output_batches, dim=0)
- 
+
     @staticmethod
     def backward(ctx, grad_outputs: torch.Tensor):
         intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
         forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
-        ctx.sequence_manager.rpc_info # lazy init
+        ctx.sequence_manager.rpc_info  # lazy init

         batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
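Both forward and backward split the full batch so that each sub-batch carries at most MAX_TOKENS_IN_BATCH tokens: batch_size counts sequences per sub-batch and is clamped to at least 1 for long inputs. A quick sketch of the arithmetic, with hypothetical tensor shapes:

    import torch

    MAX_TOKENS_IN_BATCH = 1024

    inputs = torch.randn(8, 512, 16)  # hypothetical (batch, seq_len, hidden) shapes
    batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)  # 1024 // 512 = 2 sequences
    input_batches = inputs.split(batch_size)  # Tensor.split chunks along dim 0

    print([tuple(b.shape) for b in input_batches])  # [(2, 512, 16)] * 4, i.e. 1024 tokens each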