@@ -37,6 +37,10 @@ async def run_forward(
     *inputs: torch.Tensor,
     **kwargs
 ) -> Tuple[torch.Tensor, ...]:
+    """
+    Serializes input tensors and runs a forward pass of a remote expert; *inputs must follow the expert's input_schema.
+    """
+
     # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
     # detach to avoid pickling the computation graph
     assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
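Note that the assertion above only checks the *number* of keyword arguments against rpc_info["keyword_names"], not the names themselves. A minimal sketch of the convention it enforces (the schema and tensors here are hypothetical, and run_forward's leading parameters are elided by this hunk):

    # hypothetical schema advertised by a remote expert
    rpc_info = {"keyword_names": ("attention_mask",)}

    kwargs = {"attention_mask": mask}   # passes: 1 kwarg vs. 1 expected name
    kwargs = {}                         # fails the assert: 0 vs. 1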
@@ -68,6 +72,9 @@ async def run_backward(
     intemediate_inputs: List[torch.Tensor],
     grad_outputs: List[torch.Tensor],
 ) -> Sequence[torch.Tensor]:
+    """
+    Serializes grad_outputs and runs a backward pass of a remote expert, returning gradients w.r.t. its inputs.
+    """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
     inputs_and_grad_outputs = tuple(nested_flatten((intemediate_inputs, grad_outputs_cpu)))
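The .cpu() copy keeps device-specific storage out of the serialized payload before the saved inputs and gradients are flattened together. A minimal sketch of that flattening step, assuming hivemind's nested_flatten utility and made-up shapes:

    import torch
    from hivemind.utils import nested_flatten

    intermediate_inputs = [torch.randn(1, 4, 8)]            # saved at forward time
    grad_outputs_cpu = (torch.randn(1, 4, 8).cpu(),)        # grads w.r.t. the span's outputs
    payload = tuple(nested_flatten((intermediate_inputs, grad_outputs_cpu)))
    # payload is a flat tuple of tensors, ready to serialize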
@@ -83,11 +90,20 @@ async def run_backward(
 
 
 async def async_forward(
     inputs: torch.Tensor,
-    sequence_manager: RemoteSequenceManager
-    ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
+    sequence_manager: RemoteSequenceManager,
+    start_index: int = 0,
+    end_index: Optional[int] = None
+) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
+    """
+    Runs forward over the remote blocks in [start_index, end_index), retrying failed spans; returns outputs, per-span inputs, and the spans used.
+    """
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
+
+    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
+    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
+
+    sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
     done_sequences = []
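The new start_index/end_index arguments let callers run a forward pass over only part of the block chain, defaulting to the full chain; this is what the backward-pass recovery further below relies on. A hypothetical usage sketch (the manager instance and indices are made up):

    # full chain over every block the manager knows about
    outputs, intermediates, spans = await async_forward(inputs, manager)

    # partial chain, e.g. to recompute blocks [3, 7) after a failure
    outputs, intermediates, spans = await async_forward(inputs, manager, start_index=3, end_index=7)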
@@ -109,11 +125,10 @@ async def async_forward(
                 inputs = outputs
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
-                backup_sequences = sequence_manager[span.start: span.end].make_sequence()
+                logging.debug(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                backup_sequences = sequence_manager.make_sequence(span.start)
                 assert backup_sequences[0].start == span.start
-                assert backup_sequences[-1].end == span.end
-                sequences = backup_sequences + sequences[1:]
+                sequences = backup_sequences
 
     return outputs, intermediate_inputs, done_sequences
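This hunk also changes the retry semantics: instead of re-planning only the failed span and splicing it into the old route, the manager now re-plans the entire remaining route starting at span.start (hence the single remaining assert). Roughly, with the actual RPC stubbed out by a hypothetical forward_over helper:

    sequences = sequence_manager.make_sequence(start_index, end_index)
    while sequences:
        span = sequences.pop(0)
        try:
            inputs = forward_over(span, inputs)   # stub for the remote forward call
        except Exception:
            # discard the stale plan; re-route everything from the failed block onward
            sequences = sequence_manager.make_sequence(span.start)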
@@ -124,6 +139,9 @@ async def async_backward(
     forward_sequences: Sequence[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager
 ) -> Sequence[torch.Tensor]:
+    """
+    Runs backward over the spans visited at forward time, recomputing the forward pass for a span if its server fails.
+    """
 
     assert len(intermediate_inputs) == len(forward_sequences)
     # TODO think about grads w.r.t. deep prompts
@@ -144,15 +162,15 @@ async def async_backward(
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 _, backup_intermediate_inputs, backup_forward_sequences = await async_forward(
-                    inputs, sequence_manager[span.start: span.end]  # TODO: new sequence manager requires new rpc_info init and hence freezes
+                    inputs, sequence_manager, start_index=span.start, end_index=span.end
                 )
 
-                forward_sequences = forward_sequences + backup_forward_sequences
-                intermediate_inputs = intermediate_inputs + backup_intermediate_inputs
-
                 assert len(intermediate_inputs) == len(forward_sequences)
                 assert backup_forward_sequences[0].start == span.start
                 assert backup_forward_sequences[-1].end == span.end
+
+                forward_sequences.extend(backup_forward_sequences)
+                intermediate_inputs.extend(backup_intermediate_inputs)
     return grad_outputs
 
 
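Recovery during backward has to go through async_forward first: replacement servers need the span's activations recomputed before they can take gradients through it, which is exactly what the new start_index/end_index parameters enable. Moving the extend() calls after the asserts also means a failed sanity check no longer leaves the two lists half-updated. Schematically (names as in the hunk above):

    # on failure of `span` during backward
    _, inter, spans = await async_forward(inputs, sequence_manager,
                                          start_index=span.start, end_index=span.end)
    assert spans[0].start == span.start and spans[-1].end == span.end
    forward_sequences.extend(spans)       # re-queue the recomputed spans...
    intermediate_inputs.extend(inter)     # ...together with their activations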
@@ -208,7 +226,6 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
 
         grad_input_batches = RemoteExpertWorker.run_coroutine(
             _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, ctx.sequence_manager)
-            # async_backward((grad_output_batches[0], ), intermediate_input_batches[0], forward_sequences[0], ctx.sequence_manager)
        )
         grad_inputs = [grad_input_batch[0] for grad_input_batch in grad_input_batches]
         grad_inputs = torch.cat(grad_inputs, dim=0)
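The deleted line was a leftover single-batch variant kept as a comment. For context, the surrounding code splits along the batch dimension, runs backward per sub-batch, then concatenates the per-sub-batch input grads back along dim 0; schematically, with made-up tensors:

    import torch

    # each sub-batch yields a tuple of grads; element 0 is the grad w.r.t. the inputs
    grad_input_batches = [(torch.randn(2, 4, 8),), (torch.randn(2, 4, 8),)]
    grad_inputs = torch.cat([batch[0] for batch in grad_input_batches], dim=0)  # shape (4, 4, 8)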