# benchmark_tensor_compression.py
import argparse
import time
from typing import Tuple

import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
  7. use_hivemind_log_handler("in_root_logger")
  8. logger = get_logger(__name__)
  9. def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> [float, float, int]:
  10. t = time.time()
  11. serialized = serialize_torch_tensor(tensor, compression_type)
  12. result = deserialize_torch_tensor(serialized)
  13. return time.time() - t, (tensor - result).square().mean(), serialized.ByteSize()
  14. if __name__ == "__main__":
  15. parser = argparse.ArgumentParser()
  16. parser.add_argument("--size", type=int, default=10_000_000, required=False)
  17. parser.add_argument("--seed", type=int, default=7348, required=False)
  18. parser.add_argument("--num_iters", type=int, default=30, required=False)
  19. args = parser.parse_args()
  20. torch.manual_seed(args.seed)
  21. X = torch.randn(args.size, dtype=torch.float32)
  22. for name, compression_type in CompressionType.items():
  23. total_time = 0
  24. compression_error = 0
  25. total_size = 0
  26. for i in range(args.num_iters):
  27. iter_time, iter_distortion, size = benchmark_compression(X, compression_type)
  28. total_time += iter_time
  29. compression_error += iter_distortion
  30. total_size += size
  31. total_time /= args.num_iters
  32. compression_error /= args.num_iters
  33. total_size /= args.num_iters
  34. logger.info(
  35. f"Compression type: {name}, time: {total_time:.5f}, compression error: {compression_error:.5f}, "
  36. f"size: {int(total_size):d}"
  37. )