|
@@ -241,7 +241,8 @@ class TaskPool(TaskPoolBase):
|
|
|
def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[torch.Tensor]):
|
|
|
"""send results for a processed batch, previously loaded through load_batch_to_runtime"""
|
|
|
batch_outputs = [
|
|
|
- tensor.to(device="cpu").share_memory_().detach().requires_grad_(tensor.requires_grad)
|
|
|
+ tensor.to(device="cpu", non_blocking=False).share_memory_().detach().requires_grad_(tensor.requires_grad)
|
|
|
+ # note: non_blocking is deliberately left False in tensor.to; combining a non-blocking copy with share_memory_() is undefined behavior
|
|
|
for tensor in batch_outputs
|
|
|
]
|
|
|
self.outputs_sender.send((batch_index, batch_outputs))
|