Bläddra i källkod

address the comments

dbaranchuk 3 år sedan
förälder
incheckning
ed86b36911
1 ändrade filer med 29 tillägg och 12 borttagningar
  1. 29 12
      src/client/async_forward_backward.py

+ 29 - 12
src/client/async_forward_backward.py

@@ -37,6 +37,10 @@ async def run_forward(
     *inputs: torch.Tensor,
     **kwargs
 ) -> Tuple[torch.Tensor, ...]:
+    """
+    TODO: add description
+    """
+
     # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
     # detach to avoid pickling the computation graph
     assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
@@ -68,6 +72,9 @@ async def run_backward(
     intemediate_inputs: List[torch.Tensor], 
     grad_outputs: List[torch.Tensor], 
 ) -> Sequence[torch.Tensor]:
+    """
+    TODO: add description
+    """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
     inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
@@ -83,11 +90,20 @@ async def run_backward(
 
 async def async_forward(
     inputs: torch.Tensor, 
-    sequence_manager: RemoteSequenceManager
-    ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
+    sequence_manager: RemoteSequenceManager,
+    start_index: int = 0, 
+    end_index: Optional[int] = None
+) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
+    """
+    TODO: add description
+    """
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
-    sequences = sequence_manager.make_sequence()
+
+    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
+    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+
+    sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
     done_sequences = []
 
@@ -109,11 +125,10 @@ async def async_forward(
                 inputs = outputs
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
-                backup_sequences = sequence_manager[span.start: span.end].make_sequence()
+                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                backup_sequences = sequence_manager.make_sequence(span.start)
                 assert backup_sequences[0].start == span.start
-                assert backup_sequences[-1].end == span.end
-                sequences = backup_sequences + sequences[1:]
+                sequences = backup_sequences
 
     return outputs, intermediate_inputs, done_sequences
 
@@ -124,6 +139,9 @@ async def async_backward(
     forward_sequences: Sequence[RemoteSpanInfo], 
     sequence_manager: RemoteSequenceManager
 ) -> Sequence[torch.Tensor]:
+    """
+    TODO: add description
+    """
 
     assert len(intermediate_inputs) == len(forward_sequences)
     # TODO think about grads w.r.t. deep prompts
@@ -144,15 +162,15 @@ async def async_backward(
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 _, backup_intermediate_inputs, backup_forward_sequences = await async_forward(
-                    inputs, sequence_manager[span.start: span.end] # TODO: new sequence manager requires new rpc_info init and hence freezes
+                    inputs, sequence_manager, start_index=span.start, end_index=span.end 
                 )
 
-                forward_sequences = forward_sequences + backup_forward_sequences
-                intermediate_inputs = intermediate_inputs + backup_intermediate_inputs
-
                 assert len(intermediate_inputs) == len(forward_sequences)
                 assert backup_forward_sequences[0].start == span.start
                 assert backup_forward_sequences[-1].end == span.end
+
+                forward_sequences.extend(backup_forward_sequences)
+                intermediate_inputs.extend(backup_intermediate_inputs)
     return grad_outputs
 
 
@@ -208,7 +226,6 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
 
         grad_input_batches = RemoteExpertWorker.run_coroutine(
             _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
-            # async_backward((grad_output_batches[0], ), intermediate_input_batches[0], forward_sequences[0], ctx.sequence_manager)
         )
         grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
         grad_inputs = torch.cat(grad_inputs, dim=0)