|
@@ -11,11 +11,11 @@ import torch
|
|
from hivemind import use_hivemind_log_handler
|
|
from hivemind import use_hivemind_log_handler
|
|
from hivemind.moe.server.module_backend import ModuleBackend
|
|
from hivemind.moe.server.module_backend import ModuleBackend
|
|
from hivemind.moe.server.task_pool import Task, TaskPool
|
|
from hivemind.moe.server.task_pool import Task, TaskPool
|
|
-from hivemind.utils import InvalidStateError, MPFuture, get_logger
|
|
|
|
|
|
+from hivemind.utils import InvalidStateError, get_logger
|
|
|
|
|
|
from src.bloom.from_pretrained import BloomBlock
|
|
from src.bloom.from_pretrained import BloomBlock
|
|
from src.server.cache import MemoryCache
|
|
from src.server.cache import MemoryCache
|
|
-from src.server.task_broker import SimpleBroker, TaskBrokerBase
|
|
|
|
|
|
+from src.server.task_broker import DustBrokerBase, SimpleBroker
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
use_hivemind_log_handler("in_root_logger")
|
|
logger = get_logger(__file__)
|
|
logger = get_logger(__file__)
|
|
@@ -30,7 +30,7 @@ class PrioritizedTask:
|
|
|
|
|
|
|
|
|
|
class PrioritizedTaskPool(TaskPool):
|
|
class PrioritizedTaskPool(TaskPool):
|
|
- def __init__(self, *args, broker: TaskBrokerBase = SimpleBroker(), **kwargs):
|
|
|
|
|
|
+ def __init__(self, *args, broker: DustBrokerBase = SimpleBroker(), **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
super().__init__(*args, **kwargs)
|
|
self.broker = broker
|
|
self.broker = broker
|
|
self.dust_queue = mp.Queue(maxsize=self.tasks.maxsize)
|
|
self.dust_queue = mp.Queue(maxsize=self.tasks.maxsize)
|