@@ -1,10 +1,13 @@
 import pytest
 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import DHT, get_logger, use_hivemind_log_handler, BatchTensorDescriptor, MSGPackSerializer
+from hivemind.proto import runtime_pb2
+
+from petals.data_structures import UID_DELIMITER
 from test_utils import *
 
 from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import NoSpendingPolicy, RemoteSequential
+from petals.client import RemoteSequential, RemoteSequenceManager
 from petals.client.remote_model import DistributedBloomConfig
 
 use_hivemind_log_handler("in_root_logger")
@@ -43,6 +46,42 @@ def test_remote_sequential():
     (second_half_outputs * grad_proj).sum().backward()
     assert torch.allclose(test_inputs.grad, full_grad)
 
+    # test RemoteSequential with lossy compression
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    lossy_sequential = RemoteSequential(
+        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
+    )
+
+    test_inputs.grad = None
+    approx_outputs = lossy_sequential(test_inputs)
+    (approx_outputs * grad_proj).sum().backward()
+
+    assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
+    assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
+    assert abs(approx_outputs - full_outputs).mean() < 0.01
+    assert abs(test_inputs.grad - full_grad).mean() < 0.3
+
+
+class DummyCustomSequenceManager(RemoteSequenceManager):
+    """A sequence manager that compresses inputs/outputs during forward and backward pass."""
+
+    @property
+    def rpc_info(self):
+        rpc_info = super().rpc_info
+        dims = (2048, 1024)
+        compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16)
+        rpc_info["forward_schema"] = (compressed_input_schema,), dict()  # (args, kwargs)
+        return rpc_info
+
+    def get_request_metadata(self, protocol: str, *args, **kwargs):
+        if protocol == "rpc_forward":
+            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
+        elif protocol == "rpc_backward":
+            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,)))
+        else:
+            assert protocol == "rpc_inference"
+            return super().get_request_metadata(protocol, *args, **kwargs)
+
+
 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):