benchmark_tensor_compression.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
  1. import argparse
  2. import time
  3. import torch
  4. from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
  5. from hivemind.proto.runtime_pb2 import CompressionType
  6. from hivemind.utils.logging import get_logger, use_hivemind_log_handler
  7. use_hivemind_log_handler("in_root_logger")
  8. logger = get_logger(__name__)
  9. def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
  10. t = time.time()
  11. deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
  12. return time.time() - t
  13. if __name__ == "__main__":
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument("--size", type=int, default=10000000, required=False)
  16. parser.add_argument("--seed", type=int, default=7348, required=False)
  17. parser.add_argument("--num_iters", type=int, default=30, required=False)
  18. args = parser.parse_args()
  19. torch.manual_seed(args.seed)
  20. X = torch.randn(args.size)
  21. for name, compression_type in CompressionType.items():
  22. tm = 0
  23. for i in range(args.num_iters):
  24. tm += benchmark_compression(X, compression_type)
  25. tm /= args.num_iters
  26. logger.info(f"Compression type: {name}, time: {tm}")