浏览代码

handle pin_memory to avoid runtimeerror (#139)

justheuristic 4 年之前
父节点
当前提交
62e674195a
共有 2 个文件被更改,包括 11 次插入、3 次删除
  1. 1 1
      hivemind/__init__.py
  2. 10 2
      hivemind/utils/tensor_descr.py

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.24'
+__version__ = '0.8.25'

+ 10 - 2
hivemind/utils/tensor_descr.py

@@ -32,7 +32,7 @@ class TensorDescriptor(DescriptorBase):
 
     @classmethod
     def from_tensor(cls, tensor: torch.Tensor):
-        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, tensor.is_pinned())
+        return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, safe_check_pinned(tensor))
 
     def make_empty(self, **kwargs):
         properties = asdict(self)
@@ -53,9 +53,17 @@ class BatchTensorDescriptor(TensorDescriptor):
     def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE):
         return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout,
                    device=tensor.device, requires_grad=tensor.requires_grad,
-                   pin_memory=torch.cuda.is_available() and tensor.is_pinned(),
+                   pin_memory=safe_check_pinned(tensor),
                    compression=compression if tensor.is_floating_point() else CompressionType.NONE)
 
     def make_empty(self, batch_size, **kwargs):
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
         return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs)
+
+    
def safe_check_pinned(tensor: torch.Tensor) -> bool:
    """ Check whether or not a tensor is pinned. If torch cannot initialize cuda, returns False instead of error. """
    try:
        # Short-circuit on CPU-only hosts; is_pinned() itself may raise
        # RuntimeError when the CUDA runtime fails to initialize.
        if not torch.cuda.is_available():
            return False
        return tensor.is_pinned()
    except RuntimeError:
        # Broken/unavailable CUDA setup: treat the tensor as not pinned.
        return False