瀏覽代碼

add todo & black

dbaranchuk 3 年之前
父節點
當前提交
d94008311d
共有 1 個文件被更改,包括 9 次插入6 次删除
  1. 9 6
      src/client/sequential_autograd.py

+ 9 - 6
src/client/sequential_autograd.py

@@ -22,7 +22,7 @@ async def run_expert_forward(
     """
     Serializes input tensors and calls "expert_forward".
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.  
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
 
     # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
@@ -39,6 +39,7 @@ async def run_expert_forward(
     forward_inputs = nested_flatten(forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
 
+    # TODO: figure out whether we should use run_in_executor here
     serialized_tensors = (
         serialize_torch_tensor(tensor, proto.compression)
         for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
@@ -59,7 +60,7 @@ async def run_expert_backward(
     """
     Serializes grad outputs and calls "expert_backward".
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
-    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.  
+    but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
@@ -78,7 +79,7 @@ async def sequential_forward(
     inputs: torch.Tensor, sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
-    Constructs a routing path from <start_index> to <end_index>. 
+    Constructs a routing path from <start_index> to <end_index>.
     Performs chained forward for each subsequence of blocks on the path.
     If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
     """
@@ -141,7 +142,9 @@ async def sequential_backward(
                 span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
 
-                grad_outputs = await run_expert_backward(span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs)
+                grad_outputs = await run_expert_backward(
+                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs
+                )
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
@@ -159,12 +162,12 @@ async def sequential_backward(
 
 
 async def _gather_forward(input_batches, sequence_manager):
-    """ Wrapper for asyncio.gather to perform parallel sequential forwards """
+    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
     return await asyncio.gather(*[sequential_forward(input_batch, sequence_manager) for input_batch in input_batches])
 
 
 async def _gather_backward(grad_output_batches, intermediate_input_batches, forward_sequences, sequence_manager):
-    """ Wrapper for asyncio.gather to perform parallel sequential backwards """
+    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
             sequential_backward((grad_output,), input_batch, spans, sequence_manager)