
add docstr

commit 2e760319ab

+ 1 - 1
src/petals/client/inference_session.py

@@ -93,7 +93,7 @@ class _ServerInferenceSession:
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
-        :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+        :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
           if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
         """
         if self.closed:
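
For context, a minimal usage sketch of the documented prompts shape (all sizes below are illustrative, and the commented-out step() call assumes an already-open inference session; neither is part of this commit):

import torch

# Illustrative sizes only; hid_size must match the model's hidden dimension.
num_layers, batch_size, prefix_len, hid_size = 4, 1, 16, 4096
prompts = torch.zeros(num_layers, batch_size, prefix_len, hid_size)
assert prompts.shape == (num_layers, batch_size, prefix_len, hid_size)
# outputs = session.step(inputs, prompts=prompts)  # `session` and `inputs` are placeholders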

+ 10 - 22
src/petals/client/remote_forward_backward.py

@@ -19,30 +19,30 @@ from petals.data_structures import ModuleUID, RPCInfo


 async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
 ) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
         timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in outputs.tensors]


 async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
 ) -> List[torch.Tensor]:
     grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)),
         timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]


 async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
 ) -> List[torch.Tensor]:
     parts = (
-        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
@@ -52,10 +52,10 @@ async def _forward_stream(


 async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig
 ) -> List[torch.Tensor]:
     parts = (
-        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part])
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
@@ -68,31 +68,19 @@ async def run_remote_forward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    *inputs: torch.Tensor,
+    *forward_inputs: torch.Tensor,
     config: ClientConfig,
     metadata: Optional[bytes] = None,
-    **kwargs,
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
-
-    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
-    # detach to avoid pickling the computation graph
-    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
-    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
-
-    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
-    forward_inputs = tuple(nested_flatten((inputs, kwargs)))
     args_schema, kwargs_schema = rpc_info["forward_schema"]
     compression = args_schema[0].compression
     forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
-    # TODO: create more explicit way to check servers schema and client's structure
-    assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
-
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
@@ -106,7 +94,7 @@ async def run_remote_forward(
     size = sum(t.element_size() * t.nelement() for t in inputs)
     forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
-    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
+    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata)
    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
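
A hedged sketch of the calling convention after this change: keyword tensors are no longer forwarded as **kwargs; the caller flattens everything into positional tensors (e.g. with pack_args_kwargs, as sequential_autograd.py does below) before calling run_remote_forward. The wrapper name and the input tensors here are placeholders, not part of this commit:

# Sketch under assumptions: `stub`, `rpc_info`, `config`, and the input tensors are
# prepared by the surrounding client code; pack_args_kwargs is the same helper used
# in sequential_autograd.py below.
async def call_block(uid, stub, rpc_info, config, hidden_states, prompts, metadata=None):
    # Keyword arguments are flattened into positional tensors up front instead of being sent as **kwargs.
    flat_tensors, args_structure = pack_args_kwargs(hidden_states, prompts)
    return await run_remote_forward(uid, stub, rpc_info, *flat_tensors, config=config, metadata=metadata)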
 
 
 
 

+ 14 - 2
src/petals/client/sequential_autograd.py

@@ -4,7 +4,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 import asyncio
 import itertools
 from collections import deque
-from typing import List, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Dict, Any

 import torch
 from hivemind import MSGPackSerializer
@@ -29,14 +29,25 @@ async def sequential_forward(
     sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
     end_index: Optional[int] = None,
+    block_kwargs: Sequence[Dict[str, Any]] = (),
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
     Performs chained forward for each subsequence of blocks on the path.
     If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
+
+    :param inputs: initial hidden states of shape [batch_size, sequence_length, hidden_size]
+    :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+          if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
+    :param sequence_manager: a running SequenceManager used to select remote servers and handle failures
+    :param start_index: run remote blocks starting from this index
+    :param end_index: run remote blocks up to (but not including) this index
+    :param block_kwargs: optional per-block keyword arguments. Must be a sequence with one dictionary for each block
     """

     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
+    assert len(block_kwargs) in (0, 1, end_index - start_index), \
+        f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"

     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
@@ -68,7 +79,8 @@ async def sequential_forward(
                 span = sequences.popleft()

                 stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
-                flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
+                flat_tensors, args_structure = pack_args_kwargs(
+                    inputs, prompts[span.start : span.end], *block_kwargs[span.start : span.end])

                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 metadata = sequence_manager.get_request_metadata(
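
For orientation, a hedged sketch of how a caller might supply block_kwargs to sequential_forward after this change. Everything except sequential_forward itself is a placeholder; the one-dict-per-block shape follows the assert added above:

# Sketch only: a hypothetical caller passing per-block kwargs. The inputs, prompts,
# sequence_manager, and index values are assumed to be set up elsewhere.
async def run_blocks(inputs, prompts, sequence_manager, start_index, end_index):
    block_kwargs = [{} for _ in range(end_index - start_index)]  # one kwargs dict per remote block
    outputs, intermediate_inputs, done_spans = await sequential_forward(
        inputs, prompts, sequence_manager,
        start_index=start_index, end_index=end_index, block_kwargs=block_kwargs,
    )
    return outputs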