base.py 3.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import dataclasses
  2. from abc import ABC, abstractmethod
  3. from enum import Enum, auto
  4. from typing import Any, Optional
  5. import numpy as np
  6. import torch
  7. from hivemind.proto import runtime_pb2
  8. from hivemind.utils.tensor_descr import TensorDescriptor
  9. Key = Any
  10. class TensorRole(Enum):
  11. ACTIVATION = auto()
  12. PARAMETER = auto()
  13. GRADIENT = auto()
  14. OPTIMIZER = auto()
  15. UNSPECIFIED = auto()
  16. @dataclasses.dataclass(frozen=True)
  17. class CompressionInfo:
  18. """Auxiliary data structure that contains information about the tensor that determines how it is compressed"""
  19. key: Key # name or index of the tensor from named parameters, optimizer state dict or i/o structure
  20. descriptor: TensorDescriptor # data structure that defines shape, dtype, layout and device information
  21. role: TensorRole = TensorRole.UNSPECIFIED # which role does the tensor play with respect to the model
  22. part_index: int = 0 # if tensor is sliced into parts, this represents the index within one tensor
  23. part_size: Optional[int] = None # if tensor is sliced into parts, this is the _maximum_ number of values per part
  24. @classmethod
  25. def from_tensor(cls, tensor: torch.Tensor, key: Key = None, descriptor: TensorDescriptor = None, **kwargs):
  26. return cls(key, descriptor or TensorDescriptor.from_tensor(tensor), **kwargs)
  27. def get_part(self, part_index: int, part_size: Optional[int]):
  28. return CompressionInfo(self.key, self.descriptor, self.role, part_index=part_index, part_size=part_size)
  29. class CompressionBase(ABC):
  30. """A base class that applies compression algorithm to a pytorch tensor"""
  31. compression_type: runtime_pb2.CompressionType
  32. @abstractmethod
  33. def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
  34. """
  35. Applies compression algorithm to a tensor based on their meta-parameters
  36. :param tensor: a pytorch tensor to compress; depending on the applicaiton, it is a full tensor or a part
  37. :param info: meta-information about the tensor; if partitioning is used, this still describes the full tensor
  38. :param allow_inplace: if True, compression can (but doesn't have to) to modify tensor in-place for efficiency
  39. :returns: a protobuf message that encodes the tensor
  40. """
  41. ...
  42. @abstractmethod
  43. def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
  44. """Create a pytorch tensor from the serialized outputs of .compress"""
  45. ...
  46. @abstractmethod
  47. def estimate_compression_ratio(self, info: CompressionInfo) -> float:
  48. """Estimate the compression ratio without doing the actual compression; lower ratio = better compression"""
  49. ...
  50. def __repr__(self):
  51. return f"hivemind.{self.__class__.__name__}()"
  52. class NoCompression(CompressionBase):
  53. """A dummy compression strategy that preserves the original tensor as is."""
  54. compression_type = runtime_pb2.CompressionType.NONE
  55. def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
  56. array = tensor.numpy()
  57. return runtime_pb2.Tensor(
  58. compression=self.compression_type,
  59. buffer=array.tobytes(),
  60. size=array.shape,
  61. dtype=array.dtype.name,
  62. requires_grad=tensor.requires_grad,
  63. )
  64. def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
  65. array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
  66. return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
  67. def estimate_compression_ratio(self, info: CompressionInfo) -> float:
  68. return 1.0