@@ -46,11 +46,11 @@ class TensorDescriptor(DescriptorBase):
             tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
         )
 
-    def make_empty(self, **kwargs):
+    def make_zeros(self, **kwargs):
         properties = asdict(self)
         properties.update(kwargs)
         properties.pop("compression")
-        return torch.empty(**properties)
+        return torch.zeros(**properties)
 
 
 def _str_to_torch_type(name: str, torch_type: type):
@@ -86,9 +86,9 @@ class BatchTensorDescriptor(TensorDescriptor):
             compression=compression if tensor.is_floating_point() else CompressionType.NONE,
         )
 
-    def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
+    def make_zeros(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
-        return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
+        return super().make_zeros(size=(*batch_size, *self.shape[1:]), **kwargs)
 
     def packb(self) -> bytes:
         obj_dict = asdict(self)