|
@@ -179,7 +179,7 @@ class TaskPool(TaskPoolBase):
|
|
|
# save batch futures, _output_loop will deliver on them later
|
|
|
pending_batches[batch_index] = batch_tasks
|
|
|
|
|
|
- logger.debug(f"{self.name}, batch {batch_index}: aggregating inputs")
|
|
|
+ logger.debug(f"{self.name}, batch {batch_index}: aggregating inputs")
|
|
|
# find or create shared arrays for current batch size
|
|
|
batch_inputs = [torch.cat([task.args[i] for task in batch_tasks]) for i in range(len(batch_tasks[0].args))]
|
|
|
batch_inputs = [inp.detach().requires_grad_(inp.requires_grad).share_memory_() for inp in batch_inputs]
|