benchmark_tensor_compression.py 1.0 KB

1234567891011121314151617181920212223242526272829303132
  1. import argparse
  2. import time
  3. import torch
  4. from hivemind.proto.runtime_pb2 import CompressionType
  5. from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
  6. def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
  7. t = time.time()
  8. deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
  9. return time.time() - t
  10. if __name__ == "__main__":
  11. parser = argparse.ArgumentParser()
  12. parser.add_argument('--size', type=int, default=10000000, required=False)
  13. parser.add_argument('--seed', type=int, default=7348, required=False)
  14. parser.add_argument('--num_iters', type=int, default=30, required=False)
  15. args = parser.parse_args()
  16. torch.manual_seed(args.seed)
  17. X = torch.randn(args.size)
  18. for name, compression_type in CompressionType.items():
  19. tm = 0
  20. for i in range(args.num_iters):
  21. tm += benchmark_compression(X, compression_type)
  22. tm /= args.num_iters
  23. print(f"Compression type: {name}, time: {tm}")