Your Name 1 year ago
parent
commit
2e760319ab

+ 1 - 1
src/petals/client/inference_session.py

@@ -93,7 +93,7 @@ class _ServerInferenceSession:
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
-        :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+        :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
           if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
         """
         if self.closed:
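
For context, a minimal sketch of the deep-prompt shapes this docstring describes. The call pattern and the `session` object are illustrative only and not part of this diff; the tensor shapes follow the docstring above.

    import torch

    num_layers, batch_size, prefix_len, hid_size = 4, 1, 8, 4096
    # DEEP prompts: one trainable prefix per remote layer, added to a prefix of that layer's outputs
    deep_prompts = torch.zeros(num_layers, batch_size, prefix_len, hid_size)

    hidden_states = torch.randn(batch_size, 1, hid_size)  # hidden states for one new token
    # assuming `session` is an open inference session spanning `num_layers` remote blocks
    outputs = session.step(hidden_states, prompts=deep_prompts)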

+ 10 - 22
src/petals/client/remote_forward_backward.py

@@ -19,30 +19,30 @@ from petals.data_structures import ModuleUID, RPCInfo
 
 
 async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, metadata: Optional[bytes] = None
 ) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), metadata=metadata),
         timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
 
 async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, metadata: Optional[bytes] = None
 ) -> List[torch.Tensor]:
     grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), metadata=metadata),
         timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
 
 
 async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, metadata: Optional[bytes] = None
 ) -> List[torch.Tensor]:
     parts = (
-        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], metadata=metadata)
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
@@ -52,10 +52,10 @@ async def _forward_stream(
 
 
 async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, metadata: Optional[bytes] = None
 ) -> List[torch.Tensor]:
     parts = (
-        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], metadata=metadata)
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
@@ -68,31 +68,19 @@ async def run_remote_forward(
     uid: ModuleUID,
     stub: StubBase,
     rpc_info: RPCInfo,
-    *inputs: torch.Tensor,
+    *forward_inputs: torch.Tensor,
     config: ClientConfig,
     metadata: Optional[bytes] = None,
-    **kwargs,
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
-
-    # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
-    # detach to avoid pickling the computation graph
-    assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
-    kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
-
-    # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
-    forward_inputs = tuple(nested_flatten((inputs, kwargs)))
     args_schema, kwargs_schema = rpc_info["forward_schema"]
     compression = args_schema[0].compression
     forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
     inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
-    # TODO: create more explicit way to check servers schema and client's structure
-    assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step"
-
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
@@ -106,7 +94,7 @@ async def run_remote_forward(
     size = sum(t.element_size() * t.nelement() for t in inputs)
     forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
-    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
+    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata)
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
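
Since arbitrary **kwargs no longer flow through these helpers, a caller is now expected to flatten positional and keyword tensors itself before invoking run_remote_forward. A rough sketch of that call pattern, mirroring the sequential_autograd.py hunk below; pack_args_kwargs appears there, but its import path and the surrounding names (span_uids, stub, sequence_manager, metadata) are illustrative assumptions here, not part of this diff.

    # flatten inputs, prompts and any per-block kwargs into a single tensor list plus a structure descriptor
    flat_tensors, args_structure = pack_args_kwargs(inputs, prompts, *block_kwargs)
    # args_structure is assumed to reach the server inside the serialized request metadata
    outputs = await run_remote_forward(
        span_uids,                      # chained ModuleUID of the remote blocks
        stub,                           # connection stub for the selected server
        sequence_manager.rpc_info,      # schema used to build BatchTensorDescriptors
        *flat_tensors,                  # pre-flattened forward inputs
        config=sequence_manager.config,
        metadata=metadata,              # serialized request metadata (bytes)
    )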
 
 

+ 14 - 2
src/petals/client/sequential_autograd.py

@@ -4,7 +4,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 import asyncio
 import itertools
 from collections import deque
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import MSGPackSerializer
@@ -29,14 +29,25 @@ async def sequential_forward(
     sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
     end_index: Optional[int] = None,
+    block_kwargs: Sequence[Dict[str, Any]] = (),
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
     Performs chained forward for each subsequence of blocks on the path.
     If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
+
+    :param inputs: initial hidden states of shape [batch_size, sequence_length, hidden_size]
+    :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+          if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
+    :param sequence_manager: a running SequenceManager used to select remote servers and handle failures
+    :param start_index: run remote blocks starting from this index
+    :param end_index: run remote blocks up to (but not including) this index
+    :param block_kwargs: optional per-block keyword arguments; a sequence with one dictionary for each block (may be empty)
     """
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
+    assert len(block_kwargs) in (0, 1, end_index - start_index), \
+        f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
 
     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
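
A hedged usage example for the new block_kwargs argument; the kwargs values below (e.g. attention_mask) are hypothetical, and only the signature change is part of this commit. The unpacking of the return value follows the annotated return type above.

    # inside an async context: one kwargs dict per remote block in the [start_index, end_index) range
    block_kwargs = [{} for _ in range(end_index - start_index)]
    block_kwargs[0] = {"attention_mask": attention_mask}  # hypothetical per-block keyword argument

    outputs, intermediate_inputs, done_spans = await sequential_forward(
        inputs,
        prompts,
        sequence_manager,
        start_index=start_index,
        end_index=end_index,
        block_kwargs=block_kwargs,
    )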
@@ -68,7 +79,8 @@ async def sequential_forward(
                 span = sequences.popleft()
 
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
-                flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end])
+                flat_tensors, args_structure = pack_args_kwargs(
+                    inputs, prompts[span.start : span.end], *block_kwargs[span.start : span.end])
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 metadata = sequence_manager.get_request_metadata(