tensor_descr.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import warnings
  2. from dataclasses import dataclass, asdict
  3. import torch
  4. from hivemind.proto.runtime_pb2 import CompressionType
  # Batch size used for dummy runs only.
  DUMMY_BATCH_SIZE = 3

  # Suppress the "CUDA initialization" UserWarning.
  # Cures https://github.com/pytorch/pytorch/issues/47038
  # NOTE: `message` is interpreted by the warnings module as a regular expression
  # matched against the beginning of the warning text.
  warnings.filterwarnings(
      action="ignore",
      message="CUDA initialization*",
      category=UserWarning,
  )
  @dataclass(frozen=True)
  class DescriptorBase:
      """Common base class for descriptors; declares no fields of its own."""
  @dataclass(frozen=True)
  class TensorDescriptor(DescriptorBase):
      """
      Immutable record of the arguments needed to allocate a tensor via ``torch.empty``,
      together with a ``compression`` tag that is not forwarded to ``torch.empty``.
      """

      # Every field except `compression` is passed to torch.empty by make_empty().
      size: tuple
      dtype: torch.dtype = None
      layout: torch.layout = torch.strided
      device: torch.device = None
      requires_grad: bool = False
      pin_memory: bool = False
      # Dropped before allocation; presumably consumed by serialization code elsewhere.
      compression: CompressionType = CompressionType.NONE

      @property
      def shape(self):
          """Alias for ``size``."""
          return self.size

      @classmethod
      def from_tensor(cls, tensor: torch.Tensor):
          """
          Build a descriptor mirroring the shape, dtype, layout, device, requires_grad
          and pinned-memory status of ``tensor``; ``compression`` keeps its default.
          """
          pinned = safe_check_pinned(tensor)
          return cls(
              tensor.shape,
              tensor.dtype,
              tensor.layout,
              tensor.device,
              tensor.requires_grad,
              pinned,
          )

      def make_empty(self, **kwargs):
          """
          Allocate an uninitialized tensor with ``torch.empty`` from this descriptor's fields.

          :param kwargs: overrides for individual fields, applied on top of the stored values
          :returns: the result of ``torch.empty`` called with the merged arguments
          """
          empty_kwargs = {**asdict(self), **kwargs}
          # torch.empty has no such argument, so it is always stripped (even if overridden)
          del empty_kwargs['compression']
          return torch.empty(**empty_kwargs)
  @dataclass(frozen=True)
  class BatchTensorDescriptor(TensorDescriptor):
      """ Describes batched data: a torch Tensor whose 0-th dimension is left unspecified (None) """

      def __init__(self, *instance_size, **kwargs):
          """
          Accepts the per-instance size either as ``*size`` or as one list / tuple / torch.Size
          (kept for compatibility); ``None`` is prepended as the batch dimension.

          :param instance_size: size of a single instance, without the batch dimension
          :param kwargs: remaining TensorDescriptor fields (dtype, device, ...)
          """
          is_single_sequence = (
              len(instance_size) == 1
              and isinstance(instance_size[0], (list, tuple, torch.Size))
          )
          # a lone sequence argument is the size itself rather than a 1-element size
          per_instance = instance_size[0] if is_single_sequence else instance_size
          super().__init__((None, *per_instance), **kwargs)

      @classmethod
      def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
          """
          Build a descriptor from ``tensor``, dropping its 0-th dimension from the size.

          :param tensor: tensor whose properties are copied
          :param compression: kept only if ``tensor`` is floating point, otherwise replaced
            with CompressionType.NONE
          """
          if not tensor.is_floating_point():
              compression = CompressionType.NONE
          return cls(
              *tensor.shape[1:],
              dtype=tensor.dtype,
              layout=tensor.layout,
              device=tensor.device,
              requires_grad=tensor.requires_grad,
              pin_memory=safe_check_pinned(tensor),
              compression=compression,
          )

      def make_empty(self, batch_size, **kwargs):
          """
          Allocate an uninitialized tensor with ``batch_size`` as its 0-th dimension.

          :param batch_size: value substituted for the unspecified 0-th dimension
          :param kwargs: extra overrides forwarded to TensorDescriptor.make_empty
          """
          leading_dim = self.shape[0]
          assert leading_dim is None, "Make sure 0-th dimension is not specified (set to None)"
          batched_size = (batch_size, *self.shape[1:])
          return super().make_empty(size=batched_size, **kwargs)
  def safe_check_pinned(tensor: torch.Tensor) -> bool:
      """
      Tell whether ``tensor`` is pinned.

      Returns False when CUDA is reported as unavailable, and also when torch raises
      a RuntimeError during the check, instead of propagating the error.
      """
      try:
          if not torch.cuda.is_available():
              return False
          return tensor.is_pinned()
      except RuntimeError:
          return False