grpc.py

  1. """
  2. Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
  3. """
  4. import numpy as np
  5. import torch
  6. from hivemind.proto import runtime_pb2
  7. from hivemind.proto.runtime_pb2 import CompressionType
  8. FP16_MAX = 65_504
def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionType.NONE,
                           allow_inplace=False) -> runtime_pb2.Tensor:
    if compression_type == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
        assert tensor.dtype == torch.float32

        # normalize each slice along the last axis to zero mean and unit variance
        tensor = tensor if allow_inplace else tensor.clone()
        means = torch.mean(tensor, dim=-1, keepdim=True)
        tensor.sub_(means)

        stds = torch.square(tensor).sum(dim=-1, keepdim=True).div_(tensor.shape[-1]).sqrt_()
        tensor.div_(stds)

        # cast the normalized values to float16; means and stds are kept in float32
        tensor = tensor.clamp_(-FP16_MAX, FP16_MAX).to(torch.float16)
        data = b''.join((tensor.numpy().tobytes(), means.numpy().tobytes(), stds.numpy().tobytes()))

        proto = runtime_pb2.Tensor(
            compression=compression_type,
            buffer=data,
            size=tensor.shape,
            dtype='compressed_float32',
            requires_grad=tensor.requires_grad)
    else:
        # no compression: send the raw numpy buffer along with its dtype and shape
        array = tensor.numpy()
        proto = runtime_pb2.Tensor(
            compression=compression_type,
            buffer=array.tobytes(),
            size=array.shape,
            dtype=array.dtype.name,
            requires_grad=tensor.requires_grad)

    return proto
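

# Wire layout for MEANSTD_LAST_AXIS_FLOAT16, as produced above: the buffer is the
# float16 payload followed by the per-slice float32 means and float32 stds, i.e.
# 2 * numel + 4 * num_slices + 4 * num_slices bytes, where num_slices is the number
# of elements along all axes except the last. For example, a (3, 4) float32 tensor
# serializes into 24 + 12 + 12 = 48 buffer bytes.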
def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
    # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
    if serialized_tensor.compression == CompressionType.NONE:
        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype)).copy()
        tensor = torch.as_tensor(array).view(*serialized_tensor.size).requires_grad_(serialized_tensor.requires_grad)
    elif serialized_tensor.compression == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
        # the last 8 * stats_count bytes hold the float32 means and stds, one pair per slice
        stats_size = list(serialized_tensor.size)
        stats_size[-1] = 1
        stats_count = np.prod(stats_size)
        means = serialized_tensor.buffer[-8 * stats_count:-4 * stats_count]
        stds = serialized_tensor.buffer[-4 * stats_count:]
        means = torch.as_tensor(np.frombuffer(means, dtype=np.float32)).view(*stats_size)
        stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32)).view(*stats_size)
        # the remaining prefix is the float16 payload; undo the normalization
        array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16)
        tensor = torch.as_tensor(array).to(torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
    else:
        raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
    return tensor
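

# Minimal round-trip sketch (illustrative only, not part of the library API): assumes
# the compiled hivemind runtime_pb2 protos are importable and that the tensor lives on
# CPU, since both functions call .numpy() directly.
if __name__ == "__main__":
    original = torch.randn(4, 16)

    # lossless path: raw bytes, exact reconstruction
    restored = deserialize_torch_tensor(serialize_torch_tensor(original))
    assert torch.allclose(original, restored)

    # lossy path: per-row mean/std normalization with a float16 payload
    compressed = serialize_torch_tensor(original, CompressionType.MEANSTD_LAST_AXIS_FLOAT16)
    approx = deserialize_torch_tensor(compressed)
    print("max abs error:", (original - approx).abs().max().item())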