tensor_descr.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from dataclasses import dataclass, asdict
  2. import torch
  3. from hivemind.proto.runtime_pb2 import CompressionType
  4. DUMMY_BATCH_SIZE = 3 # used for dummy runs only
  5. @dataclass(init=True, repr=True, frozen=True)
  6. class DescriptorBase:
  7. pass
  8. @dataclass(init=True, repr=True, frozen=True)
  9. class TensorDescriptor(DescriptorBase):
  10. size: tuple
  11. dtype: torch.dtype = None
  12. layout: torch.layout = torch.strided
  13. device: torch.device = None
  14. requires_grad: bool = False
  15. pin_memory: bool = False
  16. compression: CompressionType = CompressionType.NONE
  17. @property
  18. def shape(self):
  19. return self.size
  20. @classmethod
  21. def from_tensor(cls, tensor: torch.Tensor):
  22. return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, tensor.is_pinned())
  23. def make_empty(self, **kwargs):
  24. properties = asdict(self)
  25. properties.update(kwargs)
  26. return torch.empty(**properties)
  27. @dataclass(repr=True, frozen=True)
  28. class BatchTensorDescriptor(TensorDescriptor):
  29. """ torch Tensor with a variable 0-th dimension, used to describe batched data """
  30. def __init__(self, *instance_size, **kwargs): # compatibility: allow initializing with *size
  31. if len(instance_size) == 1 and isinstance(instance_size[0], (list, tuple, torch.Size)):
  32. instance_size = instance_size[0] # we were given size as the only parameter instead of *parameters
  33. super().__init__((None, *instance_size), **kwargs)
  34. @classmethod
  35. def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
  36. return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
  37. device=tensor.device, requires_grad=tensor.requires_grad,
  38. pin_memory=torch.cuda.is_available() and tensor.is_pinned(),
  39. compression=compression if tensor.is_floating_point() else CompressionType.NONE)
  40. def make_empty(self, batch_size, **kwargs):
  41. assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
  42. return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs)