# serialization.py
  1. from __future__ import annotations
  2. from typing import AsyncIterator, Dict, Iterable, List, Optional
  3. import torch
  4. from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression
  5. from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
  6. from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
  7. from hivemind.proto import runtime_pb2
  8. from hivemind.utils.streaming import combine_from_streaming
  9. BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
  10. NONE=NoCompression(),
  11. FLOAT16=Float16Compression(),
  12. MEANSTD_16BIT=ScaledFloat16Compression(),
  13. QUANTILE_8BIT=Quantile8BitQuantization(),
  14. UNIFORM_8BIT=Uniform8BitQuantization(),
  15. )
# Import-time consistency check between the protobuf enum and BASE_COMPRESSION_TYPES:
# every CompressionType name must have a registered strategy, and that strategy's
# own `compression_type` must map back to the name it is registered under.
# NOTE: these are plain asserts, so the check is skipped when Python runs with -O.
for key in runtime_pb2.CompressionType.keys():
    assert key in BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer"
    actual_compression_type = BASE_COMPRESSION_TYPES[key].compression_type
    assert (
        runtime_pb2.CompressionType.Name(actual_compression_type) == key
    ), f"Compression strategy for {key} has inconsistent type"
  22. def serialize_torch_tensor(
  23. tensor: torch.Tensor,
  24. compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
  25. info: Optional[CompressionInfo] = None,
  26. allow_inplace: bool = False,
  27. **kwargs,
  28. ) -> runtime_pb2.Tensor:
  29. """Serialize a given tensor into a protobuf message using the specified compression strategy"""
  30. assert tensor.device == torch.device("cpu")
  31. compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)]
  32. info = info or CompressionInfo.from_tensor(tensor, **kwargs)
  33. return compression.compress(tensor, info, allow_inplace)
  34. def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
  35. """Restore a pytorch tensor from a protobuf message"""
  36. compression = BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)]
  37. return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad)
  38. async def deserialize_tensor_stream(
  39. stream: AsyncIterator[Iterable[runtime_pb2.Tensor]],
  40. ) -> List[torch.Tensor]:
  41. """Async wrapper of combine_from_streaming that combines tensors from a stream of parts and deserializes them"""
  42. tensors = []
  43. tensor_parts = []
  44. async for parts in stream:
  45. for part in parts:
  46. if part.dtype and tensor_parts:
  47. tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts)))
  48. tensor_parts = []
  49. tensor_parts.append(part)
  50. if tensor_parts:
  51. tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts)))
  52. return tensors