justheuristic, 2 years ago
Commit a4155a628a

+ 8 - 12
src/client/remote_forward_backward.py

@@ -2,7 +2,7 @@
 Utility functions that call RPC forward or backward on a single remote server
 """
 import asyncio
-from typing import Iterable, List, Sequence, Tuple
+from typing import Iterable, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
@@ -63,7 +63,8 @@ async def _backward_stream(
 
 
 async def run_remote_forward(
-    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float,
+    metadata: Optional[bytes] = None, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.
@@ -102,11 +103,8 @@ async def run_remote_forward(
 
     # call RPC on remote server
     size = sum(t.element_size() * t.nelement() for t in inputs)
-    if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
-    else:
-        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
-
+    forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _forward_unary
+    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
@@ -118,6 +116,7 @@ async def run_remote_backward(
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
     timeout: float,
+    metadata: Optional[bytes] = None,
     **kwargs,
 ) -> Sequence[torch.Tensor]:
     """
@@ -146,9 +145,6 @@ async def run_remote_backward(
     )
 
     size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
-    if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
-    else:
-        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
-
+    backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE else _backward_unary
+    deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
     return deserialized_grad_inputs

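The refactor above collapses the duplicated if/else into a single dispatch expression: payloads above MAX_UNARY_PAYLOAD_SIZE go through the streaming RPC, smaller ones through a single unary call. A minimal standalone sketch of that pattern, where payload_size and choose_rpc are hypothetical names for illustration and the threshold value is assumed (the real constant is imported by this module):

import torch
from typing import Callable, Sequence

MAX_UNARY_PAYLOAD_SIZE = 2**20  # assumed placeholder value; the real constant comes from the module's imports

def payload_size(tensors: Sequence[torch.Tensor]) -> int:
    # same byte count computed in run_remote_forward / run_remote_backward
    return sum(t.element_size() * t.nelement() for t in tensors)

def choose_rpc(tensors: Sequence[torch.Tensor], stream_fn: Callable, unary_fn: Callable) -> Callable:
    # stream large payloads; send small ones in a single unary request
    return stream_fn if payload_size(tensors) > MAX_UNARY_PAYLOAD_SIZE else unary_fn
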
+ 10 - 0
src/client/sequence_manager.py

@@ -44,6 +44,7 @@ class RemoteSequenceManager:
         self.timeout, self.min_backoff = timeout, min_backoff
         self._rpc_info = None
         self.lock_changes = threading.Lock()
+        self.policy = NoSpendingPolicy()
         self.update_()
 
         for uid, info in zip(self.block_uids, self.block_infos):
@@ -165,3 +166,12 @@ class RemoteSequenceManager:
         if attempt_no == 0:
             return 0
         return self.min_backoff * 2 ** (attempt_no - 1)
+
+    def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[bytes]:
+        """
+        :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
+    :param args: request-specific inputs, typically block uids and input tensors
+        :param kwargs: additional request context, such as remote peer ID
+        :returns: msgpack-serialized metadata dict that will be passed alongside a given request
+        """
+        return MSGPackSerializer.dumps(dict(points=self.policy.get_points(protocol, *args, **kwargs)))

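The new get_request_metadata hook packs the spending policy's points into a msgpack blob that travels with each request. A minimal round-trip, using hivemind's MSGPackSerializer (the same serializer the code above and the tests below import):

from hivemind import MSGPackSerializer

metadata = MSGPackSerializer.dumps(dict(points=0.0))  # client side: NoSpendingPolicy always yields 0 points
decoded = MSGPackSerializer.loads(metadata)           # server side, inside the RPC handler
assert decoded == {"points": 0.0}
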
+ 7 - 1
src/client/sequential_autograd.py

@@ -72,8 +72,11 @@ async def sequential_forward(
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+                metadata = sequence_manager.get_request_metadata(
+                    "rpc_forward", span_uids, *inputs_and_prompts)
                 (outputs,) = await run_remote_forward(
-                    span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
+                    span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts,
+                    timeout=sequence_manager.timeout, metadata=metadata
                 )
 
                 assert isinstance(outputs, torch.Tensor)
@@ -146,6 +149,8 @@ async def sequential_backward(
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                metadata = sequence_manager.get_request_metadata(
+                    "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id)
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids,
                     stub,
@@ -154,6 +159,7 @@ async def sequential_backward(
                     grad_outputs,
                     prompts[span.start : span.end],
                     timeout=sequence_manager.timeout,
+                    metadata=metadata,
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts_reversed.extend(span_grad_prompts)

+ 7 - 4
src/client/spending_policy.py

@@ -1,14 +1,17 @@
+"""
+An interface for exchanging internal "BLOOM points" for higher-priority compute requests. NOT IMPLEMENTED.
+The intent is to let Petals participants earn points by helping others while idle (e.g. at night), then use these
+points to run their own compute experiments faster. See Section 4 of https://arxiv.org/abs/2209.01188 for discussion.
+"""
 from abc import ABC, abstractmethod
 
-from hivemind.proto.runtime_pb2 import ExpertRequest
-
 
 class SpendingPolicyBase(ABC):
     @abstractmethod
-    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+    def get_points(self, protocol: str, *args, **kwargs) -> float:
         pass
 
 
 class NoSpendingPolicy(SpendingPolicyBase):
-    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+    def get_points(self, protocol: str, *args, **kwargs) -> float:
         return 0.0

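With the signature now keyed on the protocol name rather than an ExpertRequest, custom policies reduce to a small subclass. A hypothetical example for illustration only (this commit ships NoSpendingPolicy as the sole implementation), assuming SpendingPolicyBase from the file above:

class FlatBackwardSpendingPolicy(SpendingPolicyBase):
    """Hypothetical policy: spend a flat number of points per backward pass, nothing otherwise."""

    def __init__(self, points_per_backward: float = 1.0):
        self.points_per_backward = points_per_backward

    def get_points(self, protocol: str, *args, **kwargs) -> float:
        return self.points_per_backward if protocol == "rpc_backward" else 0.0
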
+ 14 - 14
src/server/handler.py

@@ -240,17 +240,17 @@ class TransformerConnectionHandler(ConnectionHandler):
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
         outputs_schema = requested_backends[-1].outputs_schema
 
-        if metadata.get("output_compressions") is not None:
-            assert isinstance(metadata["output_compressions"], (list, tuple)), "output_compression must be a tuple/list"
-            output_compressions = tuple(metadata["output_compressions"])
-            assert all(isinstance(c, int) for c in output_compressions), "output_compression must contain integers"
-            assert len(output_compressions) == 1, f"output_compression tuple should have 1 element"
+        if metadata.get("output_compression") is not None:
+            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
+            output_compression = tuple(metadata["output_compression"])
+            assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
+            assert len(output_compression) == 1, "output_compression tuple should have 1 element"
         else:
-            output_compressions = tuple(tensor.compression for tensor in outputs_schema)
+            output_compression = tuple(tensor.compression for tensor in outputs_schema)
 
         return [
             serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
-            for result, proto, compression in zip([hidden_states], outputs_schema, output_compressions)
+            for result, proto, compression in zip([hidden_states], outputs_schema, output_compression)
         ]
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
@@ -308,17 +308,17 @@ class TransformerConnectionHandler(ConnectionHandler):
             nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
         )  # TODO generalize
 
-        if metadata.get("output_compressions") is not None:
-            assert isinstance(metadata["output_compressions"], (list, tuple)), "output_compression must be a tuple/list"
-            output_compressions = tuple(metadata["output_compressions"])
-            assert all(isinstance(c, int) for c in output_compressions), "output_compression must contain integers"
-            assert len(output_compressions) == len(grads), f"output_compression should have {len(grads)} elements"
+        if metadata.get("output_compression") is not None:
+            assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list"
+            output_compression = tuple(metadata["output_compression"])
+            assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers"
+            assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements"
         else:
-            output_compressions = tuple(tensor.compression for tensor in flat_grads_schema)
+            output_compression = tuple(tensor.compression for tensor in flat_grads_schema)
 
         return [
             serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True)
-            for result, proto, compression in zip(grads, flat_grads_schema, output_compressions)
+            for result, proto, compression in zip(grads, flat_grads_schema, output_compression)
         ]
 
     def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]:

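In both handlers, the effect of the renamed output_compression field is the same: each output tensor is serialized with the compression type requested in the metadata instead of the schema default. A standalone sketch using hivemind's serialize_torch_tensor and the protobuf CompressionType enum (both already imported by this codebase); the example tensor and the requested tuple are made up:

import torch
from hivemind import serialize_torch_tensor
from hivemind.proto import runtime_pb2

hidden_states = torch.randn(1, 8, 64)               # made-up 3d output tensor
requested = (runtime_pb2.CompressionType.FLOAT16,)  # as read from metadata["output_compression"]
serialized = [
    serialize_torch_tensor(tensor, compression, allow_inplace=True)
    for tensor, compression in zip([hidden_states], requested)
]
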
+ 40 - 2
tests/test_remote_sequential.py

@@ -1,11 +1,13 @@
 import pytest
 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import DHT, BatchTensorDescriptor, MSGPackSerializer, get_logger, use_hivemind_log_handler
+from hivemind.proto import runtime_pb2
 from test_utils import *
 
-from src import RemoteSequential
+from src import RemoteSequenceManager, RemoteSequential
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_model import DistributedBloomConfig
+from src.data_structures import UID_DELIMITER
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -43,6 +45,42 @@ def test_remote_sequential():
     (second_half_outputs * grad_proj).sum().backward()
     assert torch.allclose(test_inputs.grad, full_grad)
 
+    # test RemoteSequential with lossy compression
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    lossy_sequential = RemoteSequential(
+        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
+    )
+
+    test_inputs.grad = None
+    approx_outputs = lossy_sequential(test_inputs)
+    (approx_outputs * grad_proj).sum().backward()
+
+    assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
+    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-4), "compression was not used"
+    assert abs(approx_outputs - full_outputs).mean() < 0.01
+    assert abs(test_inputs.grad - full_grad).mean() < 0.1
+
+
+class DummyCustomSequenceManager(RemoteSequenceManager):
+    """A sequence manager that compresses inputs/outputs during forward and backward pass."""
+
+    @property
+    def rpc_info(self):
+        rpc_info = super().rpc_info
+        dims = (2048, 1024)
+        compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16)
+        rpc_info["forward_schema"] = (compressed_input_schema,), dict()  # (args, kwargs)
+        return rpc_info
+
+    def get_request_metadata(self, protocol: str, *args, **kwargs):
+        if protocol == "rpc_forward":
+            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
+        elif protocol == "rpc_backward":
+            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,)))
+        else:
+            assert protocol == "rpc_inference"
+            return super().get_request_metadata(protocol, *args, **kwargs)
+
 
 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):