@@ -14,8 +14,9 @@ from typing import List, Tuple, Dict, Any, Generator
 
 import torch
 
-from ..utils import SharedFuture
+from hivemind.utils import SharedFuture, get_logger
 
+logger = get_logger(__name__)
 
 Task = namedtuple("Task", ("future", "args"))
 
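Note: the import switches from a relative path to the absolute `hivemind.utils`, and the module now creates one logger at import time. A minimal stand-in for `get_logger` (illustrative only, assuming it wraps the standard `logging` module; this is not hivemind's actual implementation):

    import logging

    def get_logger(name: str) -> logging.Logger:
        # hypothetical stand-in: a per-module logger with a basic format
        logging.basicConfig(level=logging.DEBUG,
                            format="[%(asctime)s] %(name)s %(levelname)s: %(message)s")
        return logging.getLogger(name)

    logger = get_logger(__name__)
    logger.debug("pool-0 getting next task")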
@@ -78,7 +79,6 @@ class TaskPool(TaskPoolBase):
 
         # interaction with Runtime
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
-        self.batch_received = mp.Event()  # runtime can notify pool that it can send next batch
         self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs
 
         if start:
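With `batch_received` removed, the pool and runtime synchronize through the two one-way pipes alone: `send` buffers the next batch without waiting for the runtime to finish copying the previous one, and `recv` blocks until data arrives. A minimal sketch of the pattern (names here are illustrative, not the pool's API):

    import multiprocessing as mp

    def consumer(receiver):
        # recv() blocks until data arrives; no separate Event handshake is needed
        for _ in range(3):
            print("received:", receiver.recv())

    if __name__ == "__main__":
        receiver, sender = mp.Pipe(duplex=False)
        worker = mp.Process(target=consumer, args=(receiver,))
        worker.start()
        for i in range(3):
            sender.send(i)  # returns once the payload is queued in the pipe
        worker.join()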
@@ -107,12 +107,11 @@ class TaskPool(TaskPoolBase):
             batch = []
             total_size = 0
             try:
+                logger.debug(f"{self.uid} getting next task")
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
-                exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
-                for task in batch:
-                    task.future.set_exception(exc)
-                raise exc
+                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                continue
 
             task_size = self.get_task_size(task)
 
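Behavior change worth noting: on a queue timeout, the old code failed every task already collected into `batch` with a `TimeoutError` and tore down the iterator; the new code logs a warning and `continue`s, so an idle pool simply keeps polling. The retry pattern in isolation (a standalone sketch, not the pool's code):

    from queue import Empty, Queue

    def get_next_task(tasks: Queue, timeout: float = 1.0):
        # poll until a task arrives instead of failing pending work on a quiet interval
        while True:
            try:
                return tasks.get(timeout=timeout)
            except Empty:
                print("timeout reached, still waiting for tasks...")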
@@ -126,10 +125,10 @@ class TaskPool(TaskPoolBase):
             total_size += task_size
 
     def run(self, *args, **kwargs):
-        print(f'Starting pool, pid={os.getpid()}')
+        logger.info(f'{self.uid} starting, pid={os.getpid()}')
         pending_batches = {}  # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
-                                         name=f'{self.uid}-pool_output_loop')
+                                         name=f'{self.uid}_output')
         try:
             output_thread.start()
             self._pool_input_loop(pending_batches, *args, **kwargs)
@@ -144,11 +143,8 @@ class TaskPool(TaskPoolBase):
         prev_num_tasks = 0  # number of tasks currently in shared buffer
         batch_index = max(pending_batches.keys(), default=0)
         batch_iterator = self.iterate_minibatches(*args, **kwargs)
-        self.batch_received.set()  # initial state: no batches/outputs pending
 
         while True:
-            self.batch_received.wait()  # wait for runtime to receive (copy) previous batch
-
             # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
             # assumes that tasks are processed in the same order as they are created
             for skip_i in range(prev_num_tasks):
@@ -156,18 +152,21 @@ class TaskPool(TaskPoolBase):
                 if skip_i == prev_num_tasks - 1:
                     self.priority = finished_task_timestamp
 
+            logger.debug(f"{self.uid} getting next batch")
             batch_tasks = next(batch_iterator)
             # save batch futures, _output_loop will deliver on them later
             pending_batches[batch_index] = batch_tasks
 
+            logger.debug(f"{self.uid}, batch {batch_index}: aggregating inputs")
             # find or create shared arrays for current batch size
             batch_inputs = [
                 torch.cat([task.args[i] for task in batch_tasks]).share_memory_()
                 for i in range(len(batch_tasks[0].args))
             ]
 
-            self.batch_received.clear()  # sending next batch...
+            logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
             self.batch_sender.send((batch_index, batch_inputs))
+            logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
             prev_num_tasks = len(batch_tasks)
             batch_index += 1
 
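The input loop concatenates each argument across tasks and moves the result into shared memory, so the runtime process can read the batch without re-serializing tensor data through the pipe. A standalone illustration of the two torch calls involved:

    import torch

    a = torch.randn(2, 4)   # task 1 contributes 2 rows
    b = torch.randn(3, 4)   # task 2 contributes 3 rows
    batch = torch.cat([a, b]).share_memory_()  # one (5, 4) tensor in shared memory
    assert batch.is_shared() and batch.shape == (5, 4)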
@@ -175,16 +174,19 @@ class TaskPool(TaskPoolBase):
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
 
         while True:
+            logger.debug(f"{self.uid} waiting for results from runtime")
             payload = self.outputs_receiver.recv()
             if isinstance(payload, BaseException):
                 raise payload
             else:
                 batch_index, batch_outputs = payload
+                logger.debug(f"{self.uid}, batch {batch_index}: got results")
 
                 # split batch into partitions for individual tasks
                 batch_tasks = pending_batches.pop(batch_index)
                 task_sizes = [self.get_task_size(task) for task in batch_tasks]
                 outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))
+                logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
 
                 # dispatch results to futures
                 for task, task_outputs in zip(batch_tasks, outputs_per_task):
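`torch.split_with_sizes` is the inverse of the `torch.cat` in the input loop: each task gets back exactly the rows it contributed, and the `zip(*...)` regroups per-array slices into per-task tuples. For example:

    import torch

    batch_output = torch.arange(10).reshape(5, 2)  # 5 rows of results
    task_sizes = [2, 3]                            # tasks contributed 2 and 3 rows
    first, second = torch.split_with_sizes(batch_output, task_sizes, dim=0)
    assert first.shape == (2, 2) and second.shape == (3, 2)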
@@ -200,7 +202,6 @@ class TaskPool(TaskPoolBase):
             raise TimeoutError()
 
         batch_index, batch_inputs = self.batch_receiver.recv()
-        self.batch_received.set()  # pool can now prepare next batch
         batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
         return batch_index, batch_inputs
 
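On the runtime side, nothing replaces the deleted `batch_received.set()`: `load_batch_to_runtime` now just polls, receives, and moves tensors to the compute device. Note that `non_blocking=True` only makes the host-to-device copy asynchronous when the source tensor lives in pinned memory; otherwise it degrades to a synchronous copy. A sketch (assumes a CUDA device is available):

    import torch

    x = torch.randn(5, 4).pin_memory()       # pinned memory enables async H2D copies
    if torch.cuda.is_available():
        y = x.to("cuda", non_blocking=True)  # may return before the copy completes
        torch.cuda.synchronize()             # wait before reading y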