
add todo & black

dbaranchuk, 3 years ago
parent commit d94008311d
1 changed file with 9 additions and 6 deletions
  1. src/client/sequential_autograd.py (+9 -6)

src/client/sequential_autograd.py (+9 -6)

@@ -22,7 +22,7 @@ async def run_expert_forward(
     """
     """
     Serializes input tensors and calls "expert_forward".
     Serializes input tensors and calls "expert_forward".
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.  
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
     """
 
 
     # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
     # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
@@ -39,6 +39,7 @@ async def run_expert_forward(
     forward_inputs = nested_flatten(forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
 
+    # TODO: figure out whether we should use run_in_executor here
     serialized_tensors = (
         serialize_torch_tensor(tensor, proto.compression)
         for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
@@ -59,7 +60,7 @@ async def run_expert_backward(
     """
     """
     Serializes grad outputs and calls "expert_backward".
     Serializes grad outputs and calls "expert_backward".
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.  
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
     """
 
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
@@ -78,7 +79,7 @@ async def sequential_forward(
     inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
-    Constructs a routing path from <start_index> to <end_index>. 
+    Constructs a routing path from <start_index> to <end_index>.
     Performs chained forward for each subsequence of blocks on the path.
     If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
     """
@@ -141,7 +142,9 @@ async def sequential_backward(
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
 
-                grad_outputs = await run_expert_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
+                grad_outputs = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                )
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
@@ -159,12 +162,12 @@ async def sequential_backward(
 
 
 async def _gather_forward(input_batches, sequence_manager):
-    """ Wrapper for asyncio.gather to perform parallel sequential forwards """
+    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
     return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
 
 
 async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
-    """ Wrapper for asyncio.gather to perform parallel sequential backwards """
+    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
             sequential_backward((grad_output,), input_batch, spans, sequence_manager)
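
The TODO added in this commit asks whether the serialize_torch_tensor calls should be offloaded via run_in_executor. A minimal sketch of what that could look like, assuming hivemind's serialize_torch_tensor (the import path below is an assumption and varies between hivemind versions) and a hypothetical helper name; this illustrates the idea and is not code from the commit:

    # Sketch only: offload CPU-bound tensor serialization to the default
    # thread pool so it does not block the event loop.
    # serialize_tensors_in_executor is a hypothetical helper; the hivemind
    # import path is an assumption and may differ by version.
    import asyncio

    from hivemind.compression import serialize_torch_tensor


    async def serialize_tensors_in_executor(inputs, schema):
        """Serialize each (tensor, proto) pair concurrently in the default executor."""
        loop = asyncio.get_running_loop()
        return await asyncio.gather(
            *(
                loop.run_in_executor(None, serialize_torch_tensor, tensor, proto.compression)
                for tensor, proto in zip(inputs, schema)
            )
        )

The trade-off behind the TODO: serialization is CPU-bound, so running it inline stalls the event loop on large tensors, while an executor adds thread-switching overhead that may not pay off for small ones.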