
Change make_empty to make_zeros for TensorDescriptor (#442)

Right now, make_empty fails when running a dummy forward pass through an expert that accepts LongTensor inputs (e.g., embeddings), since not all integer values are valid indices. Because the client side already initializes these empty tensors with zeros anyway, the simplest solution is to create zero tensors by default in every case.
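
A minimal sketch of the failure mode, assuming a simple embedding expert (the table size and batch shape below are illustrative):

import torch

# Hypothetical embedding expert: only indices in [0, 1000) are valid inputs.
embedding = torch.nn.Embedding(num_embeddings=1000, embedding_dim=16)

# torch.empty returns uninitialized memory, so an int64 tensor may hold
# arbitrary values far outside the valid index range.
bad_input = torch.empty(4, 128, dtype=torch.int64)
# embedding(bad_input)  # may raise IndexError: index out of range in self

# torch.zeros always yields index 0, which is valid for any non-empty table.
good_input = torch.zeros(4, 128, dtype=torch.int64)
out = embedding(good_input)  # works; out.shape == (4, 128, 16)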
Max Ryabinin 3 years ago
parent commit de0bc00a3f
3 changed files with 8 additions and 8 deletions
  1. hivemind/moe/client/moe.py (+2 −2)
  2. hivemind/moe/server/expert_backend.py (+2 −2)
  3. hivemind/utils/tensor_descr.py (+4 −4)

+ 2 - 2
hivemind/moe/client/moe.py

@@ -245,7 +245,7 @@ class _RemoteCallMany(torch.autograd.Function):
         else:
             outputs_schema = info["outputs_schema"]
         outputs = nested_map(
-            lambda descriptor: descriptor.make_empty(num_samples, max_experts, device=flat_inputs[0].device).zero_(),
+            lambda descriptor: descriptor.make_zeros(num_samples, max_experts, device=flat_inputs[0].device),
             outputs_schema,
         )
 
@@ -341,7 +341,7 @@ class _RemoteCallMany(torch.autograd.Function):
         # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]
 
         grad_inputs = nested_map(
-            lambda descr: descr.make_empty(num_samples, device=flat_grad_outputs[0].device).zero_(),
+            lambda descr: descr.make_zeros(num_samples, device=flat_grad_outputs[0].device),
             list(nested_flatten(info["forward_schema"])),
         )
 

+ 2 - 2
hivemind/moe/server/expert_backend.py

@@ -74,8 +74,8 @@ class ExpertBackend:
 
         if outputs_schema is None:
             # run expert once to get outputs schema
-            dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
-            dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
+            dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
+            dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 

+ 4 - 4
hivemind/utils/tensor_descr.py

@@ -46,11 +46,11 @@ class TensorDescriptor(DescriptorBase):
             tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor)
         )
 
-    def make_empty(self, **kwargs):
+    def make_zeros(self, **kwargs):
         properties = asdict(self)
         properties.update(kwargs)
         properties.pop("compression")
-        return torch.empty(**properties)
+        return torch.zeros(**properties)
 
 
 def _str_to_torch_type(name: str, torch_type: type):
@@ -86,9 +86,9 @@ class BatchTensorDescriptor(TensorDescriptor):
             compression=compression if tensor.is_floating_point() else CompressionType.NONE,
         )
 
-    def make_empty(self, *batch_size: int, **kwargs) -> torch.Tensor:
+    def make_zeros(self, *batch_size: int, **kwargs) -> torch.Tensor:
         assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)"
-        return super().make_empty(size=(*batch_size, *self.shape[1:]), **kwargs)
+        return super().make_zeros(size=(*batch_size, *self.shape[1:]), **kwargs)
 
     def packb(self) -> bytes:
         obj_dict = asdict(self)
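
For context, a hypothetical usage sketch of the renamed method; the descriptor construction and shapes are illustrative, and it assumes from_tensor leaves the 0-th (batch) dimension unset as the assert above requires:

import torch
from hivemind.utils.tensor_descr import BatchTensorDescriptor

# Describe a batch of int64 token tensors; the batch dimension stays None.
descr = BatchTensorDescriptor.from_tensor(torch.zeros(1, 128, dtype=torch.int64))

# make_zeros fills in the batch size at call time and returns an all-zero
# tensor, which is always a valid input for integer-indexed experts.
batch = descr.make_zeros(4)
assert batch.shape == (4, 128) and bool((batch == 0).all())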