adaptive.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
from abc import ABC, abstractmethod
from typing import Mapping, Optional, Sequence, Union

import torch

from hivemind.compression.base import CompressionBase, CompressionInfo, Key, NoCompression, TensorRole
from hivemind.compression.serialization import deserialize_torch_tensor
from hivemind.proto import runtime_pb2
  7. class AdaptiveCompressionBase(CompressionBase, ABC):
  8. @abstractmethod
  9. def choose_compression(self, info: CompressionInfo) -> CompressionBase:
  10. ...
  11. def estimate_compression_ratio(self, info: CompressionInfo) -> float:
  12. return self.choose_compression(info).estimate_compression_ratio(info)
  13. def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
  14. return self.choose_compression(info).compress(tensor, info=info, allow_inplace=allow_inplace)
  15. def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
  16. return deserialize_torch_tensor(serialized_tensor)
  17. class SizeAdaptiveCompression(AdaptiveCompressionBase):
  18. """Apply compression strategy 1 if tensor has more than :threshold: elements and strategy 2 otherwise"""
  19. def __init__(self, threshold: int, less: CompressionBase, greater_equal: CompressionBase):
  20. self.threshold, self.less, self.greater_equal = threshold, less, greater_equal
  21. def choose_compression(self, info: CompressionInfo) -> CompressionBase:
  22. return self.greater_equal if info.descriptor.numel() >= self.threshold else self.less
  23. class RoleAdaptiveCompression(AdaptiveCompressionBase):
  24. """Compress a tensor based on its role in training. Any non-specified compressions will use the "default" option"""
  25. def __init__(
  26. self,
  27. *,
  28. activation: CompressionBase = None,
  29. parameter: CompressionBase = None,
  30. gradient: CompressionBase = None,
  31. optimizer: CompressionBase = None,
  32. default: CompressionBase = NoCompression()
  33. ):
  34. self.role_compressions = {
  35. TensorRole.ACTIVATION: activation or default,
  36. TensorRole.PARAMETER: parameter or default,
  37. TensorRole.GRADIENT: gradient or default,
  38. TensorRole.OPTIMIZER: optimizer or default,
  39. TensorRole.UNSPECIFIED: default,
  40. }
  41. def choose_compression(self, info: CompressionInfo) -> CompressionBase:
  42. return self.role_compressions[info.role]
  43. class PerTensorCompression(AdaptiveCompressionBase):
  44. """Manually specify the compression strategy depending on tensor key"""
  45. def __init__(self, tensor_compressions: Union[Sequence[CompressionBase], Mapping[Key, CompressionBase]]):
  46. self.tensor_compressions = tensor_compressions
  47. def choose_compression(self, info: CompressionInfo) -> CompressionBase:
  48. return self.tensor_compressions[info.key]