# proto.py
from dataclasses import dataclass, asdict
from typing import Optional

import torch
  3. DUMMY_BATCH_SIZE = 3 # used for dummy runs only
  4. @dataclass(init=True, repr=True, frozen=True)
  5. class ProtoBase:
  6. pass
  7. @dataclass(init=True, repr=True, frozen=True)
  8. class TensorProto(ProtoBase):
  9. size: tuple
  10. dtype: torch.dtype = None
  11. layout: torch.layout = torch.strided
  12. device: torch.device = None
  13. requires_grad: bool = False
  14. pin_memory: bool = False
  15. @property
  16. def shape(self):
  17. return self.size
  18. @classmethod
  19. def from_tensor(cls, tensor: torch.Tensor):
  20. return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, tensor.is_pinned())
  21. def make_empty(self, **kwargs):
  22. properties = asdict(self)
  23. properties.update(kwargs)
  24. return torch.empty(**properties)
  25. @dataclass(repr=True, frozen=True)
  26. class BatchTensorProto(TensorProto):
  27. """ torch Tensor with a variable 0-th dimension, used to describe batched data """
  28. def __init__(self, *instance_size, **kwargs): # compatibility: allow initializing with *size
  29. if len(instance_size) == 1 and isinstance(instance_size[0], (list, tuple, torch.Size)):
  30. instance_size = instance_size[0] # we were given size as the only parameter instead of *parameters
  31. super().__init__((None, *instance_size), **kwargs)
  32. @classmethod
  33. def from_tensor(cls, tensor: torch.Tensor):
  34. return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
  35. device=tensor.device, requires_grad=tensor.requires_grad, pin_memory=tensor.is_pinned())
  36. def make_empty(self, batch_size, **kwargs):
  37. assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
  38. return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs)