Преглед изворни кода

Fix dtypes in backend schemas (#99)

Currently, the schemas use `torch.float32`, so all inputs and outputs converted to float32 before sending and after receiving on both servers and clients. This creates a huge slowdown for the system.

* This PR makes the schemas use the server's `--torch_dtype` argument (default is `torch.bloat16` for BLOOM-176B)
* an option for client to request a specific output compression. Use case 1: client sends quantized inputs and expects quantized inputs in return. Use case 2: client uses quantization for gradients w.r.t. activations, but keeps grads w.r.t. __prompts__ as is for greater precision.
* a comment explaining the purpose of NoSpendingPolicy - since we likely won't have it for the workshop
* a test with custom compression (janky implementation for testing purposes)

Co-authored-by: justheuristic <justheuristic@gmail.com>
Alexander Borzunov пре 2 година
родитељ
комит
43ac6016ac

+ 13 - 12
src/petals/client/remote_forward_backward.py

@@ -2,7 +2,7 @@
 Utility functions that call RPC forward or backward on a single remote server
 """
 import asyncio
-from typing import Iterable, List, Sequence, Tuple
+from typing import Iterable, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
@@ -63,7 +63,13 @@ async def _backward_stream(
 
 
 async def run_remote_forward(
-    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs
+    uid: ModuleUID,
+    stub: StubBase,
+    rpc_info: RPCInfo,
+    *inputs: torch.Tensor,
+    timeout: float,
+    metadata: Optional[bytes] = None,
+    **kwargs,
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.
@@ -102,11 +108,8 @@ async def run_remote_forward(
 
     # call RPC on remote server
     size = sum(t.element_size() * t.nelement() for t in inputs)
-    if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
-    else:
-        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
-
+    forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary
+    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
@@ -118,6 +121,7 @@ async def run_remote_backward(
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
     timeout: float,
+    metadata: Optional[bytes] = None,
     **kwargs,
 ) -> Sequence[torch.Tensor]:
     """
@@ -146,9 +150,6 @@ async def run_remote_backward(
     )
 
     size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
-    if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
-    else:
-        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
-
+    backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary
+    deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
     return deserialized_grad_inputs

+ 11 - 0
src/petals/client/sequence_manager.py

@@ -10,6 +10,7 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 import petals.dht_utils
+from petals.client.spending_policy import NoSpendingPolicy
 from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler
 
@@ -43,6 +44,7 @@ class RemoteSequenceManager:
         self.timeout, self.min_backoff = timeout, min_backoff
         self._rpc_info = None
         self.lock_changes = threading.Lock()
+        self.policy = NoSpendingPolicy()
         self.update_()
 
         for uid, info in zip(self.block_uids, self.block_infos):
@@ -166,3 +168,12 @@ class RemoteSequenceManager:
         if attempt_no == 0:
             return 0
         return self.min_backoff * 2 ** (attempt_no - 1)
+
+    def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[bytes]:
+        """
+        :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
+        :param args: request-specific inputs, typicall block uids and input tensors
+        :param kwargs: additional request context, such as remote peer ID
+        :returns: msgpack-serialized metadata dict that will be passed alongside a given request
+        """
+        return MSGPackSerializer.dumps(dict(points=self.policy.get_points(protocol, *args, **kwargs)))

+ 11 - 1
src/petals/client/sequential_autograd.py

@@ -72,8 +72,14 @@ async def sequential_forward(
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+                metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
                 (outputs,) = await run_remote_forward(
-                    span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
+                    span_uids,
+                    stub,
+                    sequence_manager.rpc_info,
+                    *inputs_and_prompts,
+                    timeout=sequence_manager.timeout,
+                    metadata=metadata,
                 )
 
                 assert isinstance(outputs, torch.Tensor)
@@ -146,6 +152,9 @@ async def sequential_backward(
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                metadata = sequence_manager.get_request_metadata(
+                    "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
+                )
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids,
                     stub,
@@ -154,6 +163,7 @@ async def sequential_backward(
                     grad_outputs,
                     prompts[span.start : span.end],
                     timeout=sequence_manager.timeout,
+                    metadata=metadata,
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts_reversed.extend(span_grad_prompts)

+ 7 - 4
src/petals/client/spending_policy.py

@@ -1,14 +1,17 @@
+"""
+An interface for exchanging internal "BLOOM points" for higher priority compute requests. NOT IMPLEMENTED.
+The intent is to let Petals participants earn points by helping others while idle (e.g. at night), then use these
+ points to run their own compute experiments faster. See Section 4 of https://arxiv.org/abs/2209.01188 for discussion.
+"""
 from abc import ABC, abstractmethod
 
-from hivemind.proto.runtime_pb2 import ExpertRequest
-
 
 class SpendingPolicyBase(ABC):
     @abstractmethod
-    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+    def get_points(self, protocol: str, *args, **kwargs) -> float:
         pass
 
 
 class NoSpendingPolicy(SpendingPolicyBase):
-    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+    def get_points(self, protocol: str, *args, **kwargs) -> float:
         return 0.0

+ 4 - 2
src/petals/server/backend.py

@@ -18,7 +18,7 @@ logger = get_logger(__file__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
 
-    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
+    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, BloomBlock)
         self.memory_cache = memory_cache
@@ -37,7 +37,9 @@ class TransformerBackend(ModuleBackend):
         self.backward_pool = PrioritizedTaskPool(
             self.backward, max_batch_size=max_batch_size, name=f"{self.name}_backward"
         )
-        self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+
+        assert backend_dtype is not None
+        self.dtype = backend_dtype
         self.inference_schema = (
             (
                 *self.args_schema,

+ 60 - 58
src/petals/server/handler.py

@@ -1,6 +1,6 @@
 import asyncio
 import contextlib
-from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
 
 import torch
 from async_timeout import timeout
@@ -202,14 +202,8 @@ class TransformerConnectionHandler(ConnectionHandler):
             hidden_states = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
-
-            # Serialize output and respond to client
             return runtime_pb2.ExpertResponse(
-                tensors=[
-                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                    for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-                ]
+                tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
             )
 
     async def rpc_forward_stream(
@@ -230,22 +224,34 @@ class TransformerConnectionHandler(ConnectionHandler):
             hidden_states = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-            assert (
-                isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
-            ), "hidden_states must be a 3d tensor"
-
-            # Serialize the overall output
-            serialized_output = [
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-            ]
 
             # Split the serialized_output for streaming and respond to client
-            output_split = [
-                part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-            ]
-            async for part in as_aiter(*output_split):
-                yield runtime_pb2.ExpertResponse(tensors=[part])
+            for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    def _serialize_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        requested_backends: Sequence[TransformerBackend],
+        metadata: Dict[str, Any],
+    ) -> Sequence[runtime_pb2.Tensor]:
+        """Serialize forward outputs using either outputs_schema or custom user-specified schema"""
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
+        outputs_schema = requested_backends[-1].outputs_schema
+
+        if metadata.get("output_compression") is not None:
+            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
+            output_compression = tuple(metadata["output_compression"])
+            assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
+            assert len(output_compression) == 1, f"output_compression tuple should have 1 element"
+        else:
+            output_compression = tuple(tensor.compression for tensor in outputs_schema)
+
+        return [
+            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
+            for result, proto, compression in zip([hidden_states], outputs_schema, output_compression)
+        ]
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         async with timeout(self.request_timeout):
@@ -265,21 +271,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
 
-            # Modify grad_inputs_schema to support grad_prompts
-            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-
-            grad_inputs_schema_with_prompts = (
-                requested_backends[0].args_schema * len(grads),
-                requested_backends[0].kwargs_schema,
-            )  # TODO generalize
-
-            # Serialize the overall grad_input and respond
-            return runtime_pb2.ExpertResponse(
-                tensors=[
-                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                    for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-                ]
-            )
+            return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
 
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -298,28 +290,38 @@ class TransformerConnectionHandler(ConnectionHandler):
             grads = await _rpc_backward(
                 *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
-
-            # Modify grad_inputs_schema to support grad_prompts
-            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
-            grad_inputs_schema_with_prompts = (
-                requested_backends[0].args_schema * len(grads),
-                requested_backends[0].kwargs_schema,
-            )  # TODO generalize
-
-            # Serialize the overall grad_inputs
-            serialized_grad_inputs = [
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-            ]
             # Split the serialized_grad_inputs for streaming and respond
-            output_split = [
-                part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-            ]
+            for tensor in self._serialize_grads(grads, requested_backends, metadata):
+                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
+                    yield runtime_pb2.ExpertResponse(tensors=[part])
 
-            async for part in as_aiter(*output_split):
-                yield runtime_pb2.ExpertResponse(tensors=[part])
-
-    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
+    def _serialize_grads(
+        self,
+        grads: Sequence[torch.Tensor],
+        requested_backends: Sequence[TransformerBackend],
+        metadata: Dict[str, Any],
+    ) -> Sequence[runtime_pb2.Tensor]:
+        """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        flat_grads_schema = tuple(
+            nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
+        )  # TODO generalize
+
+        if metadata.get("output_compression") is not None:
+            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
+            output_compression = tuple(metadata["output_compression"])
+            assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
+            assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements"
+        else:
+            output_compression = tuple(tensor.compression for tensor in flat_grads_schema)
+
+        return [
+            serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
+            for result, proto, compression in zip(grads, flat_grads_schema, output_compression)
+        ]
+
+    def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:
         """Check that the first request to rpc_inference is valid"""
         uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
@@ -360,7 +362,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             yield handles
 
-    def _log_request(self, method: str, uids: List[ModuleUID], context: P2PContext) -> None:
+    def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
         friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
         friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
         friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids

+ 4 - 4
src/petals/server/server.py

@@ -286,27 +286,27 @@ class ModuleContainer(threading.Thread):
                 )
 
                 if load_in_8bit:
-                    dtype = block.input_layernorm.weight.dtype
                     block = replace_8bit_linear(block)
 
                 block = block.to(device)
                 for param in block.parameters():
                     param.requires_grad = False
 
+                backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
                     memory_cache=memory_cache,
-                    backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                    backend_dtype=backend_dtype,
                     args_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
                         ),
                     ),
                     kwargs_schema={},
                     outputs_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
                         ),
                     ),
                     min_batch_size=min_batch_size,

+ 40 - 2
tests/test_remote_sequential.py

@@ -1,11 +1,13 @@
 import pytest
 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import DHT, BatchTensorDescriptor, MSGPackSerializer, get_logger, use_hivemind_log_handler
+from hivemind.proto import runtime_pb2
 from test_utils import *
 
 from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import RemoteSequential
+from petals.client import RemoteSequenceManager, RemoteSequential
 from petals.client.remote_model import DistributedBloomConfig
+from petals.data_structures import UID_DELIMITER
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -43,6 +45,42 @@ def test_remote_sequential():
     (second_half_outputs * grad_proj).sum().backward()
     assert torch.allclose(test_inputs.grad, full_grad)
 
+    # test RemoteSequential with lossy compression
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    lossy_sequential = RemoteSequential(
+        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
+    )
+
+    test_inputs.grad = None
+    approx_outputs = lossy_sequential(test_inputs)
+    (approx_outputs * grad_proj).sum().backward()
+
+    assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
+    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
+    assert abs(approx_outputs - full_outputs).mean() < 0.01
+    assert abs(test_inputs.grad - full_grad).mean() < 0.3
+
+
+class DummyCustomSequenceManager(RemoteSequenceManager):
+    """A sequence manager that compresses inputs/outputs during forward and backward pass."""
+
+    @property
+    def rpc_info(self):
+        rpc_info = super().rpc_info
+        dims = (2048, 1024)
+        compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16)
+        rpc_info["forward_schema"] = (compressed_input_schema,), dict()  # (args, kwargs)
+        return rpc_info
+
+    def get_request_metadata(self, protocol: str, *args, **kwargs):
+        if protocol == "rpc_forward":
+            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
+        elif protocol == "rpc_backward":
+            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,)))
+        else:
+            assert protocol == "rpc_inference"
+            return super().get_request_metadata(protocol, *args, **kwargs)
+
 
 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):