
Add comment warning for non_blocking + share_memory (#541)

justheuristic 2 years ago
parent
commit
41a2d06048
1 file changed with 2 additions and 1 deletion
  1. hivemind/moe/server/task_pool.py (+2 −1)

+2 −1 hivemind/moe/server/task_pool.py

@@ -241,7 +241,8 @@ class TaskPool(TaskPoolBase):
     def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
         """send results for a processed batch, previously loaded through load_batch_to_runtime"""
         batch_outputs = [
-            tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
+            tensor.to(device="cpu", non_blocking=False).share_memory_().detach().requires_grad_(tensor.requires_grad)
+            # note: tensor.to deliberately does NOT use non_blocking; non_blocking + share_memory = undefined behavior
             for tensor in batch_outputs
         ]
         self.outputs_sender.send((batch_index, batch_outputs))
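
For context (not part of the commit), the sketch below illustrates the pattern the added comment defends: a blocking GPU-to-CPU copy completes before the storage is moved into shared memory, whereas an asynchronous copy could still be in flight when another process reads the shared tensor. The helper names to_shared_cpu and to_shared_cpu_unsafe are hypothetical; only the chained call mirrors task_pool.py.

    import torch


    def to_shared_cpu(tensor: torch.Tensor) -> torch.Tensor:
        # Blocking copy: the device-to-host transfer is guaranteed to be finished
        # before the storage is placed in shared memory and handed to another process.
        return (
            tensor.to(device="cpu", non_blocking=False)
            .share_memory_()
            .detach()
            .requires_grad_(tensor.requires_grad)
        )


    def to_shared_cpu_unsafe(tensor: torch.Tensor) -> torch.Tensor:
        # Anti-pattern the comment warns against: with non_blocking=True the async
        # device-to-host copy may still be in flight when share_memory_() moves the
        # storage, so a receiving process can observe stale or partial data.
        return tensor.to(device="cpu", non_blocking=True).share_memory_().detach()


    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.randn(4, 4, device=device, requires_grad=True)
        shared = to_shared_cpu(x)
        assert shared.is_shared() and shared.requires_grad

If an asynchronous copy were ever wanted here, an explicit synchronization (e.g. torch.cuda.current_stream().synchronize()) would presumably be needed before share_memory_(); the commit instead keeps the simpler default blocking copy.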