
WIP BEFORE MEETING NEED BACKWARD UPDATE

Your Name 1 year ago
commit e5c2d8eca4

+ 40 - 36
src/petals/client/remote_forward_backward.py

@@ -5,7 +5,7 @@ import asyncio
 from typing import Iterable, List, Optional, Sequence, Tuple
 
 import torch
-from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
+from hivemind import MSGPackSerializer, PeerID, nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
 from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
 from hivemind.p2p import StubBase
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
@@ -14,8 +14,11 @@ from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 from hivemind.utils.tensor_descr import BatchTensorDescriptor
 
+# import from the defining module rather than `petals` itself to avoid a circular import at package init
+from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.config import ClientConfig
-from petals.data_structures import ModuleUID, RPCInfo
+from petals.data_structures import ModuleUID, RPCInfo, CHAIN_DELIMITER
+from petals.server.handler import TransformerConnectionHandler
+from petals.utils.packaging import pack_args_kwargs
 
 
 async def _forward_unary(
@@ -65,73 +68,74 @@ async def _backward_stream(
 
 
 async def run_remote_forward(
-    uid: ModuleUID,
-    stub: StubBase,
-    rpc_info: RPCInfo,
-    *forward_inputs: torch.Tensor,
-    config: ClientConfig,
-    metadata: Optional[bytes] = None,
+    sequence_manager: RemoteSequenceManager,
+    peer_id: PeerID,
+    span_uids: Sequence[ModuleUID],
+    *args: torch.Tensor,
+    **kwargs: torch.Tensor,
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
-    args_schema, kwargs_schema = rpc_info["forward_schema"]
-    compression = args_schema[0].compression
-    forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs)
-    inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
+    merged_uid = CHAIN_DELIMITER.join(span_uids)
+    stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
+    flat_inputs, args_structure = pack_args_kwargs(*args, **kwargs)
+    metadata = sequence_manager.get_request_metadata(peer_id, "rpc_forward", span_uids, *args, **kwargs)
+    compressions = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs)
+    if compressions is None:
+        compressions = [runtime_pb2.CompressionType.NONE] * len(flat_inputs)
+    compressions = list(nested_flatten(compressions))
+    assert len(compressions) == len(flat_inputs), f"got {len(flat_inputs)} tensors but {len(compressions)} codecs"
+    inputs = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_inputs)
+
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
         *(
-            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, forward_schema)
+            loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
+            for tensor, compression in zip(inputs, compressions)
         )
     )
 
     # call RPC on remote server
     size = sum(t.element_size() * t.nelement() for t in inputs)
     forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
-    # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
-    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata)
-    return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
+    # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR
+    # metadata is a plain dict now; msgpack it together with args_structure, which the server
+    # needs to restore (args, kwargs) from the flat tensor stream
+    return await forward_fn(
+        merged_uid, serialized_tensors, stub, sequence_manager.config,
+        metadata=MSGPackSerializer.dumps(dict(metadata or {}, args_structure=args_structure)),
+    )
 
 
 async def run_remote_backward(
-    uid: ModuleUID,
-    stub: StubBase,
-    rpc_info: RPCInfo,
-    *inputs_and_grad_outputs: torch.Tensor,
-    config: ClientConfig,
-    metadata: Optional[bytes] = None,
-    **kwargs,
+    sequence_manager: RemoteSequenceManager,
+    peer_id: PeerID,
+    span_uids: Sequence[ModuleUID],
+    grad_outputs: Sequence[torch.Tensor],
+    *args: torch.Tensor,
+    **kwargs: torch.Tensor,
 ) -> Sequence[torch.Tensor]:
     """
     Serializes grad outputs and calls "rpc_backward" on a remote server.
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
-    args_schema, kwargs_schema = rpc_info["forward_schema"]
-    outputs_schema = rpc_info["outputs_schema"]
-    compression = args_schema[0].compression
-    backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs)
-    # TODO: create more explicit way to check servers schema and client's structure
-    assert (
-        len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1
-    ), "Inputs, grad_outputs and prompt tensors are necessary for a backward step"
+    merged_uid = CHAIN_DELIMITER.join(span_uids)
+    stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)
+    flat_tensors, args_structure = pack_args_kwargs([grad.cpu() for grad in grad_outputs], args, kwargs)
+    metadata = sequence_manager.get_request_metadata(peer_id, "rpc_backward", span_uids, *flat_tensors)
+    compressions = sequence_manager.get_compression_codecs(peer_id, "rpc_backward", span_uids, *flat_tensors)
+    if compressions is None:
+        compressions = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)
+    compressions = list(nested_flatten(compressions))
+    assert len(compressions) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(compressions)} codecs"
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()
     serialized_tensors = await asyncio.gather(
         *(
-            loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+            loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)
+            for tensor, compression in zip(flat_tensors, compressions)
         )
     )
 
-    size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
+    size = sum(t.element_size() * t.nelement() for t in flat_tensors)
     backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
-    deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
-    return deserialized_grad_inputs
+    # same as in run_remote_forward: attach args_structure and msgpack the metadata dict
+    return await backward_fn(
+        merged_uid, serialized_tensors, stub, sequence_manager.config,
+        metadata=MSGPackSerializer.dumps(dict(metadata or {}, args_structure=args_structure)),
+    )
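
For context, a minimal sketch of the packing contract both helpers rely on. `pack_args_kwargs` and `unpack_args_kwargs` live in `petals.utils.packaging`; the tensor shapes below are illustrative, and the round-trip assumes `unpack_args_kwargs` returns the restored (args, kwargs) pair:

    import torch
    from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs

    hidden_states = torch.randn(1, 4, 16)   # activations entering a span
    prompts = torch.zeros(2, 1, 4, 16)      # one prompt tensor per block
    flat_tensors, args_structure = pack_args_kwargs(
        hidden_states, prompts, attention_mask=torch.ones(1, 4)
    )
    # flat_tensors: every tensor found in (args, kwargs), in traversal order;
    # args_structure: a nested template that puts each tensor back in place
    restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)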

+ 16 - 8
src/petals/client/routing/sequence_manager.py

@@ -474,22 +474,30 @@ class RemoteSequenceManager:
             return 0
         return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
 
-    def get_request_metadata(
-        self, protocol: str, args_structure: Any = None, *args, **kwargs
-    ) -> Optional[Dict[str, Any]]:
+    def get_request_metadata(
+        self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
+    ) -> Optional[Dict[str, Any]]:
         """
+        :param peer_id: remote server's PeerID
         :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
-        :param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging
-        :param args: request-specific inputs, typically block uids and input tensors
-        :param kwargs: additional request context, such as remote peer ID
-        :returns: msgpack-serialized metadata dict that will be passed alongside a given request
+        :param uids: module UIDs of the blocks this request addresses
+        :param args: request-specific input tensors
+        :param kwargs: additional request keyword arguments
+        :returns: metadata dict that will be passed alongside a given request
         """
         return dict(
             points=self.policy.get_points(protocol, *args, **kwargs),
             active_adapter=self.config.active_adapter,
-            args_structure=args_structure,
         )
 
+    def get_compression_codecs(
+        self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs
+    ) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]:
+        """
+        :param peer_id: remote server's PeerID
+        :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
+        :param uids: module UIDs of the blocks this request addresses
+        :param args: request-specific input tensors
+        :param kwargs: additional request keyword arguments
+        :returns: a compression codec for each tensor in (args, kwargs), in flattening order,
+            or None to serialize every tensor uncompressed (CompressionType.NONE)
+        """
+        return None
+
     def shutdown(self):
         self._thread.shutdown()
 
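As an illustration of the new hook, a hypothetical subclass (not part of this commit; the class name and the float16 choice are made up for the example) could compress all forward transfers:

    from hivemind.proto import runtime_pb2
    from petals.utils.packaging import pack_args_kwargs

    class Float16SequenceManager(RemoteSequenceManager):  # hypothetical subclass
        def get_compression_codecs(self, peer_id, protocol, uids, *args, **kwargs):
            # one codec per tensor that pack_args_kwargs extracts from (args, kwargs)
            flat_tensors, _ = pack_args_kwargs(*args, **kwargs)
            return [runtime_pb2.CompressionType.FLOAT16] * len(flat_tensors)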

+ 15 - 29
src/petals/client/sequential_autograd.py

@@ -46,8 +46,11 @@ async def sequential_forward(
     """
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
-    assert len(block_kwargs) in (0, 1, end_index - start_index), \
-        f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
+    assert len(block_kwargs) in (
+        0,
+        1,
+        end_index - start_index,
+    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
 
     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
@@ -78,27 +81,19 @@ async def sequential_forward(
 
                 span = sequences.popleft()
 
-                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
-                flat_tensors, args_structure = pack_args_kwargs(
-                    inputs, prompts[span.start : span.end], *block_kwargs[span.start: span.end])
-
-                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-                metadata = sequence_manager.get_request_metadata(
-                    "rpc_forward", args_structure, span_uids, *flat_tensors
-                )
                 (outputs,) = await run_remote_forward(
-                    span_uids,
-                    stub,
-                    sequence_manager.rpc_info,
-                    *flat_tensors,
-                    config=sequence_manager.config,
-                    metadata=MSGPackSerializer.dumps(metadata),
+                    sequence_manager,
+                    span.peer_id,
+                    sequence_manager.block_uids[span.start : span.end],
+                    inputs,
+                    prompts[span.start : span.end],
+                    *block_kwargs[span.start : span.end],
                 )
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
 
                 # Save intermediate inputs and subsequences if the forward is already done for them
                 intermediate_inputs.append(inputs)
                 done_sequences.append(span)
 
@@ -164,23 +159,14 @@ async def sequential_backward(
                     inputs = intermediate_inputs.pop()
                     span = forward_sequences.pop()
 
-                grad_outputs_cpu = [grad.cpu() for grad in grad_outputs]
-                flat_tensors, args_structure = pack_args_kwargs(
-                    inputs, *grad_outputs_cpu, prompts[span.start : span.end]
-                )
 
-                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
-                metadata = sequence_manager.get_request_metadata(
-                    "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id
-                )
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
-                    span_uids,
+                    sequence_manager,
+                    span.peer_id,
+                    sequence_manager.block_uids[span.start : span.end],
-                    stub,
-                    sequence_manager.rpc_info,
-                    *flat_tensors,
-                    config=sequence_manager.config,
-                    metadata=MSGPackSerializer.dumps(metadata),
+                    grad_outputs,
+                    inputs,
+                    prompts[span.start : span.end],
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts_reversed.extend(span_grad_prompts)
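
For reference, a sketch of how the refactored pair is meant to be driven for a single span, matching the signatures above. The helper name and the torch.ones_like stand-in gradient are illustrative, not part of this commit:

    import torch

    async def roundtrip_one_span(sequence_manager, span, inputs, prompts):
        uids = sequence_manager.block_uids[span.start : span.end]
        # forward through one server's span of blocks
        (outputs,) = await run_remote_forward(
            sequence_manager, span.peer_id, uids, inputs, prompts[span.start : span.end]
        )
        # backward: grad_outputs come first, then the tensors the forward consumed
        grad_inputs, *grad_prompts = await run_remote_backward(
            sequence_manager, span.peer_id, uids,
            [torch.ones_like(outputs)], inputs, prompts[span.start : span.end],
        )
        return outputs, grad_inputs, grad_prompts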