Parcourir la source

black & isort

dbaranchuk il y a 3 ans
Parent
commit
c925f4d45b
1 fichiers modifiés avec 37 ajouts et 49 suppressions
  1. 37 49
      src/client/sequential_autograd.py

+ 37 - 49
src/client/sequential_autograd.py

@@ -1,29 +1,23 @@
+import asyncio
 import logging
-from typing import Optional, List, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple
 
 import torch
-import asyncio
-
 from hivemind import serialize_torch_tensor
-from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
-from hivemind.moe.client.expert import expert_forward, expert_backward
+from hivemind.moe.client.expert import expert_backward, expert_forward
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
+from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
 
 from src.client.sequence_manager import RemoteSequenceManager
+from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
-from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo, ModuleUID, RemoteSpanInfo, RPCInfo
 
-
-MAX_TOKENS_IN_BATCH=1024
+MAX_TOKENS_IN_BATCH = 1024
 
 
 async def run_forward(
-    uid: ModuleUID, 
-    stub: StubBase,
-    rpc_info: RPCInfo,
-    *inputs: torch.Tensor,
-    **kwargs
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
     TODO: add description
@@ -33,13 +27,13 @@ async def run_forward(
     # detach to avoid pickling the computation graph
     assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
     kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
-    
+
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
     forward_inputs = (inputs, kwargs)
 
     if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
         raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
- 
+
     forward_inputs = nested_flatten(forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
 
@@ -54,11 +48,11 @@ async def run_forward(
 
 
 async def run_backward(
-    uid: ModuleUID, 
+    uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    intemediate_inputs: List[torch.Tensor], 
-    grad_outputs: List[torch.Tensor], 
+    intemediate_inputs: List[torch.Tensor],
+    grad_outputs: List[torch.Tensor],
 ) -> Sequence[torch.Tensor]:
     """
     TODO: add description
@@ -77,10 +71,7 @@ async def run_backward(
 
 
 async def async_forward(
-    inputs: torch.Tensor, 
-    sequence_manager: RemoteSequenceManager,
-    start_index: int = 0, 
-    end_index: Optional[int] = None
+    inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     TODO: add description
@@ -99,9 +90,9 @@ async def async_forward(
         while True:
             try:
                 span = sequences.pop(0)
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start: span.end])
+                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                (outputs, ) = await run_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
+                (outputs,) = await run_forward(span_uids, stub, sequence_manager.rpc_info, inputs)
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -123,9 +114,9 @@ async def async_forward(
 
 async def async_backward(
     grad_outputs: Sequence[torch.Tensor],
-    intermediate_inputs: Sequence[torch.Tensor],  
-    forward_sequences: Sequence[RemoteSpanInfo], 
-    sequence_manager: RemoteSequenceManager
+    intermediate_inputs: Sequence[torch.Tensor],
+    forward_sequences: Sequence[RemoteSpanInfo],
+    sequence_manager: RemoteSequenceManager,
 ) -> Sequence[torch.Tensor]:
     """
     TODO: add description
@@ -133,24 +124,22 @@ async def async_backward(
 
     assert len(intermediate_inputs) == len(forward_sequences)
     # TODO think about grads w.r.t. deep prompts
-    
+
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         while True:
             try:
                 inputs = intermediate_inputs.pop(-1)
                 span = forward_sequences.pop(-1)
 
-                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start: span.end])
+                span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                
-                grad_outputs = await run_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
-                )
+
+                grad_outputs = await run_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 _, backup_intermediate_inputs, backup_forward_sequences = await async_forward(
-                    inputs, sequence_manager, start_index=span.start, end_index=span.end 
+                    inputs, sequence_manager, start_index=span.start, end_index=span.end
                 )
 
                 assert len(intermediate_inputs) == len(forward_sequences)
@@ -163,17 +152,18 @@ async def async_backward(
 
 
 async def _gather_forward(input_batches, sequence_manager):
-    return await asyncio.gather(*[
-        async_forward(input_batch, sequence_manager)
-        for input_batch in input_batches
-    ])
+    return await asyncio.gather(*[async_forward(input_batch, sequence_manager) for input_batch in input_batches])
 
 
 async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
-    return await asyncio.gather(*[
-        async_backward((grad_output, ), input_batch, spans, sequence_manager)
-        for grad_output, input_batch, spans in zip(grad_output_batches, intermediate_input_batches, forward_sequences)
-    ])
+    return await asyncio.gather(
+        *[
+            async_backward((grad_output,), input_batch, spans, sequence_manager)
+            for grad_output, input_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, forward_sequences
+            )
+        ]
+    )
 
 
 class _RemoteSequentialAutogradFunction(torch.autograd.Function):
@@ -181,16 +171,14 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     A pytorch autograd-compatible function that calls a sequence of transformer blocks on remote peers
     :note: this function splits input data into batches for efficient parallel processing
     """
- 
+
     @staticmethod
     def forward(ctx, inputs: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.split(batch_size)
 
-        sequence_manager.rpc_info # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(
-            _gather_forward(input_batches, sequence_manager)
-        )
+        sequence_manager.rpc_info  # lazy init
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, sequence_manager))
         assert len(outputs) == len(input_batches)
 
         output_batches = [output[0] for output in outputs]
@@ -201,12 +189,12 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         ctx.intemediate_input_batches = intemediate_input_batches
         ctx.sequences_for_batches = sequences_for_batches
         return torch.cat(output_batches, dim=0)
- 
+
     @staticmethod
     def backward(ctx, grad_outputs: torch.Tensor):
         intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
         forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
-        ctx.sequence_manager.rpc_info # lazy init
+        ctx.sequence_manager.rpc_info  # lazy init
 
         batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)