@@ -32,7 +32,7 @@ class TensorDescriptor(DescriptorBase):
 
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor):
-        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, tensor.is_pinned())
+        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, safe_check_pinned(tensor))
 
     def make_empty(self, **kwargs):
         properties = asdict(self)
@@ -53,9 +53,17 @@ class BatchTensorDescriptor(TensorDescriptor):
     def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
         return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
                    device=tensor.device, requires_grad=tensor.requires_grad,
-                   pin_memory=torch.cuda.is_available() and tensor.is_pinned(),
+                   pin_memory=safe_check_pinned(tensor),
                    compression=compression if tensor.is_floating_point() else CompressionType.NONE)
 
     def make_empty(self, batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs)
+
+
+def safe_check_pinned(tensor: torch.Tensor) -> bool:
+    """Check whether a tensor is pinned. If torch cannot initialize CUDA, return False instead of raising an error."""
+    try:
+        return torch.cuda.is_available() and tensor.is_pinned()
+    except RuntimeError:
+        return False
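
For context, a minimal standalone sketch of the behavior this patch targets. It re-states the `safe_check_pinned` helper from the diff rather than importing it, and the host environment described in the comments (CUDA runtime present but unusable, so initialization raises `RuntimeError`) is an assumption based on the helper's docstring, not something reproduced here:

```python
import torch


def safe_check_pinned(tensor: torch.Tensor) -> bool:
    # Same logic as the helper added in this diff: treat a failed CUDA
    # initialization as "not pinned" instead of propagating the RuntimeError.
    try:
        return torch.cuda.is_available() and tensor.is_pinned()
    except RuntimeError:
        return False


if __name__ == "__main__":
    t = torch.zeros(2, 2)
    # Pre-patch, TensorDescriptor.from_tensor called tensor.is_pinned() directly,
    # which can raise on hosts where torch cannot initialize CUDA.
    # Post-patch, descriptor construction degrades gracefully to pin_memory=False.
    print(safe_check_pinned(t))  # -> False on CPU-only or broken-CUDA hosts
```

Catching only `RuntimeError` keeps other programming errors visible while mapping a failed CUDA initialization to `pin_memory=False`.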