justheuristic 2 years ago
parent
commit
074947aff5
1 changed file with 41 additions and 2 deletions

+ 41 - 2
tests/test_remote_sequential.py

@@ -1,10 +1,13 @@
 import pytest
 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import DHT, get_logger, use_hivemind_log_handler, BatchTensorDescriptor, MSGPackSerializer
+from hivemind.proto import runtime_pb2
+
+from petals.data_structures import UID_DELIMITER
 from test_utils import *
 
 from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import NoSpendingPolicy, RemoteSequential
+from petals.client import RemoteSequential, RemoteSequenceManager
 from petals.client.remote_model import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
@@ -43,6 +46,42 @@ def test_remote_sequential():
     (second_half_outputs * grad_proj).sum().backward()
     assert torch.allclose(test_inputs.grad, full_grad)
 
+    # test RemoteSequential with lossy compression
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    lossy_sequential = RemoteSequential(
+        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
+    )
+
+    test_inputs.grad = None
+    approx_outputs = lossy_sequential(test_inputs)
+    (approx_outputs * grad_proj).sum().backward()
+
+    assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
+    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
+    assert abs(approx_outputs - full_outputs).mean() < 0.01
+    assert abs(test_inputs.grad - full_grad).mean() < 0.3
+
+
+class DummyCustomSequenceManager(RemoteSequenceManager):
+    """A sequence manager that compresses inputs/outputs during forward and backward pass."""
+
+    @property
+    def rpc_info(self):
+        rpc_info = super().rpc_info
+        dims = (2048, 1024)
+        compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16)
+        rpc_info["forward_schema"] = (compressed_input_schema,), dict()  # (args, kwargs)
+        return rpc_info
+
+    def get_request_metadata(self, protocol: str, *args, **kwargs):
+        if protocol == "rpc_forward":
+            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
+        elif protocol == "rpc_backward":
+            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,)))
+        else:
+            assert protocol == "rpc_inference"
+            return super().get_request_metadata(protocol, *args, **kwargs)
+
 
 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
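
Not part of the diff above: a minimal sketch of how the per-request metadata produced by DummyCustomSequenceManager.get_request_metadata could round-trip, assuming only the hivemind APIs already imported in the test (MSGPackSerializer and runtime_pb2.CompressionType). The receiving side shown here is a simplified stand-in, not the actual server handler.

from hivemind import MSGPackSerializer
from hivemind.proto import runtime_pb2

# Client side: pack the requested output compression, as the test's sequence manager does.
metadata = MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))

# Receiving side (sketch): decode the metadata and read back the requested compression codes.
decoded = MSGPackSerializer.loads(metadata)
requested = tuple(decoded.get("output_compression", ()))
assert requested == (runtime_pb2.CompressionType.FLOAT16,)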