Browse source

intermediate changes

Pavel Samygin, 2 years ago
Commit
0c6350da17

+ 2 - 2
src/client/__init__.py

@@ -1,6 +1,6 @@
-from src.client.dust_bank import DummyDustBank, DustBankBase
-from src.client.dusty_block import DustyRemoteBlock
 from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
+from src.client.priority_block import DustyRemoteBlock
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
+from src.client.spending_policy import DummySpendingPolicy, SpendingPolicyBase

+ 0 - 18
src/client/dust_bank.py

@@ -1,18 +0,0 @@
-from abc import ABC, abstractmethod
-from functools import wraps
-
-from hivemind.p2p import StubBase
-from hivemind.proto import runtime_pb2
-from hivemind.proto.runtime_pb2 import ExpertRequest
-from hivemind.utils import MSGPackSerializer, amap_in_executor
-
-
-class DustBankBase(ABC):
-    @abstractmethod
-    def get_dust(self, request: ExpertRequest, method_name: str) -> float:
-        pass
-
-
-class DummyDustBank(DustBankBase):
-    def get_dust(self, request: ExpertRequest, method_name: str) -> float:
-        return 0.0

+ 8 - 7
src/client/dusty_block.py → src/client/priority_block.py

@@ -10,19 +10,20 @@ from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
 from hivemind.utils import MSGPackSerializer, amap_in_executor
 
-from src.client.dust_bank import DustBankBase
+from src.client.spending_policy import SpendingPolicyBase
 
 
+# TODO: (greenfatguy) remove later, left for now as example
 class DustyRemoteBlock(RemoteExpert):
-    def __init__(self, bank: DustBankBase, expert_info: ExpertInfo, p2p: P2P):
-        self._bank = bank
+    def __init__(self, bank: SpendingPolicyBase, expert_info: ExpertInfo, p2p: P2P):
+        self._spending_policy = bank
         super().__init__(expert_info, p2p)
 
     def _unary_request_wrapper(self, rpc_call: Callable, rpc_name: str):
         @wraps(rpc_call)
         async def rpc(input: runtime_pb2.ExpertRequest, timeout: Optional[float] = None):
             meta = MSGPackSerializer.loads(input.metadata) if input.metadata else {}
-            meta["__dust"] = self._bank.get_dust(input, rpc_name)
+            meta["__dust"] = self._spending_policy.get_points(input, rpc_name)
             input.metadata = MSGPackSerializer.dumps(meta)
             return await rpc_call(input, timeout)
 
@@ -37,7 +38,7 @@ class DustyRemoteBlock(RemoteExpert):
                 nonlocal is_meta_set
                 if not is_meta_set:
                     meta = MSGPackSerializer.loads(chunk.metadata) if chunk.metadata else {}
-                    meta["__dust"] = self._bank.get_dust(chunk, rpc_name)
+                    meta["__dust"] = self._spending_policy.get_points(chunk, rpc_name)
                     chunk.metadata = MSGPackSerializer.dumps(meta)
                     is_meta_set = True
                 return chunk
@@ -46,7 +47,7 @@ class DustyRemoteBlock(RemoteExpert):
 
         return rpc
 
-    def _dustify_handler_stub(self, stub: StubBase) -> StubBase:
+    def _prioritize_handler_stub_calls(self, stub: StubBase) -> StubBase:
         for name, method in inspect.getmembers(stub, predicate=inspect.ismethod):
             if name.startswith("rpc"):
                 spec = inspect.getfullargspec(method)
@@ -68,4 +69,4 @@ class DustyRemoteBlock(RemoteExpert):
 
     @property
     def stub(self) -> StubBase:
-        return self._dustify_handler_stub(self._stub)
+        return self._prioritize_handler_stub_calls(self._stub)
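
A minimal standalone sketch of what both wrappers above do to an outgoing request. It reuses only calls that already appear in this diff; attach_points itself is a hypothetical helper for illustration, not part of the codebase.

from hivemind.proto.runtime_pb2 import ExpertRequest
from hivemind.utils import MSGPackSerializer

from src.client.spending_policy import DummySpendingPolicy


def attach_points(request: ExpertRequest, rpc_name: str) -> ExpertRequest:
    # Decode any existing metadata, add the "__dust" key with the points the
    # policy is willing to spend, and re-encode it, mirroring
    # _unary_request_wrapper and the streaming wrapper above.
    policy = DummySpendingPolicy()
    meta = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
    meta["__dust"] = policy.get_points(request, rpc_name)
    request.metadata = MSGPackSerializer.dumps(meta)
    return request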

+ 14 - 0
src/client/spending_policy.py

@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+
+from hivemind.proto.runtime_pb2 import ExpertRequest
+
+
+class SpendingPolicyBase(ABC):
+    @abstractmethod
+    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+        pass
+
+
+class DummySpendingPolicy(SpendingPolicyBase):
+    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
+        return 0.0
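
A hedged example of how the new interface is meant to be extended (PerMethodSpendingPolicy is illustrative and not part of this commit; a similar per-method policy appears in tests/test_dust_payment.py further down):

from hivemind.proto.runtime_pb2 import ExpertRequest

from src.client.spending_policy import SpendingPolicyBase


class PerMethodSpendingPolicy(SpendingPolicyBase):
    """Hypothetical policy that pays a fixed number of points per RPC method."""

    def __init__(self, points_per_method: dict):
        self._points = dict(points_per_method)

    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
        # Unknown methods cost nothing.
        return float(self._points.get(method_name, 0.0))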

+ 17 - 53
src/server/backend.py

@@ -15,7 +15,6 @@ from hivemind.utils import InvalidStateError, get_logger
 
 from src.bloom.from_pretrained import BloomBlock
 from src.server.cache import MemoryCache
-from src.server.task_broker import DustBrokerBase, SimpleBroker
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -23,28 +22,30 @@ logger = get_logger(__file__)
 
 @dataclass(order=True)
 class PrioritizedTask:
-    value: int
+    priority: float
     task: Task = field(compare=False)
 
 
 class PrioritizedTaskPool(TaskPool):
-    def __init__(self, *args, broker: DustBrokerBase = SimpleBroker(), **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.broker = broker
-        self.dust_queue = mp.Queue(maxsize=self.tasks.maxsize)
-        self.priority_queue = PriorityQueue(maxsize=self.tasks.maxsize)
 
-    def submit_task(self, *args: torch.Tensor, dust: float = 0.0) -> Future:
+        assert self.min_batch_size == 1, "PrioritizedTaskPool does not support batching"
+
+        self.priority_queue = mp.Queue(maxsize=self.tasks._maxsize)
+        self.prioritized_task_queue = PriorityQueue(maxsize=self.tasks._maxsize)
+
+    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> Future:
         f = super().submit_task(*args)
-        self.dust_queue.put(dust)
+        self.priority_queue.put(priority)
         return f
 
     def _priortize_tasks(self):
         """Infinite loop prioritizing incoming tasks"""
         while True:
             task = self.tasks.get(block=True)
-            dust = self.dust_queue.get(block=True)
-            self.priority_queue.put(PrioritizedTask(-self.broker(task, dust), task), block=True)
+            priority = self.priority_queue.get(block=True)
+            self.prioritized_task_queue.put(PrioritizedTask(priority, task), block=True)
 
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
@@ -71,58 +72,19 @@ class PrioritizedTaskPool(TaskPool):
     # TODO: this is a copy-paste of the original method, except that we use different queue
     def iterate_minibatches(self, *args, **kwargs):
         """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
-        batch = []
-        total_size = 0
-
-        while True:
-            if total_size >= self.min_batch_size and self.priority_queue.empty():
-                yield batch
-                batch = []
-                total_size = 0
-            try:
-                logger.debug(f"{self.name} getting next task")
-                task = self.priority_queue.get(timeout=self.timeout)
-            except Empty:
-                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
-                continue
-
-            task_size = self.get_task_size(task)
-
-            if total_size + task_size > self.max_batch_size:
-                yield batch
-                batch = []
-                total_size = 0
-
-            try:
-                if task.future.set_running_or_notify_cancel():
-                    batch.append(task)
-                    total_size += task_size
-            except InvalidStateError as e:
-                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
-
-
-class InferenceTaskPool(TaskPool):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        assert self.min_batch_size == 1, "min_batch_size in InferenceTaskPool cannot be greater 1"
-
-    def iterate_minibatches(self, *args, **kwargs):
-        """Form minibatches by grouping one or more tasks together up to self.max_batch_size"""
-
         while True:
             try:
                 logger.debug(f"{self.name} getting next task")
-                task = self.tasks.get(timeout=self.timeout)
+                task: PrioritizedTask = self.prioritized_task_queue.get(timeout=self.timeout)
             except Empty:
                 logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
                 continue
 
             try:
-                if task.future.set_running_or_notify_cancel():
+                if task.task.future.set_running_or_notify_cancel():
                     yield [task]
             except InvalidStateError as e:
-                logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
+                logger.debug(f"Failed to add task to batch: {task.task.future} raised {e}")
 
 
 class TransformerBackend(ModuleBackend):
@@ -137,9 +99,11 @@ class TransformerBackend(ModuleBackend):
         for name, buf in self.module.named_buffers():
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
-        self.inference_pool = InferenceTaskPool(
+        self.inference_pool = PrioritizedTaskPool(
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
+        self.forward_pool = PrioritizedTaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
+        self.backward_pool = PrioritizedTaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
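
The ordering of the new pool follows from @dataclass(order=True): PrioritizedTask instances compare by their priority field (task is excluded via compare=False), so the PriorityQueue hands out the task with the smallest priority value first. A plain-Python sketch of that behaviour, with no hivemind dependencies:

from dataclasses import dataclass, field
from queue import PriorityQueue
from typing import Any


@dataclass(order=True)
class PrioritizedTask:
    priority: float
    task: Any = field(compare=False)


queue: PriorityQueue = PriorityQueue()
queue.put(PrioritizedTask(priority=2.0, task="low-priority request"))
queue.put(PrioritizedTask(priority=0.0, task="high-priority request"))
assert queue.get().task == "high-priority request"  # smallest priority value is served first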

+ 0 - 21
src/server/task_broker.py

@@ -1,21 +0,0 @@
-from abc import ABC, abstractmethod
-
-from hivemind.moe.server.task_pool import Task
-
-
-class DustBrokerBase(ABC):
-    """Abstract class for DustBroker whose reponsibility is to evaluate task profit"""
-
-    @abstractmethod
-    def __call__(self, task: Task, dust: float) -> float:
-        """Evaluates task value by the amout of dust promised"""
-        pass
-
-
-class SimpleBroker(DustBrokerBase):
-    """Simple implementation of DustBroker which counts amount of dust per task size"""
-
-    def __call__(self, task: Task, dust: float) -> float:
-        # TODO: this was taken from original task pool. Is is right?
-        task_size = len(task.args[0]) if task.args else 1
-        return dust / task_size

+ 19 - 0
src/server/task_prioritizer.py

@@ -0,0 +1,19 @@
+from abc import ABC, abstractmethod
+
+from hivemind.moe.server.task_pool import Task
+
+
+class TaskPrioritizerBase(ABC):
+    """Abstract class for DustBroker whose reponsibility is to evaluate task profit"""
+
+    @abstractmethod
+    def prioritize(self, task: Task, points: float, *args, **kwargs) -> float:
+        """Evaluates task value by the amout of points given"""
+        pass
+
+
+class DummyTaskPrioritizer(TaskPrioritizerBase):
+    """Simple implementation of DustBroker which counts amount of dust per task size"""
+
+    def prioritize(self, task: Task, points: float, *args, **kwargs) -> float:
+        return 0.0
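
A hedged sketch of a non-dummy prioritizer (PointsPerSizePrioritizer is illustrative, not part of this commit): it mirrors the removed SimpleBroker by spreading the promised points over the task size, and negates the result because PrioritizedTaskPool serves the smallest priority value first.

from hivemind.moe.server.task_pool import Task

from src.server.task_prioritizer import TaskPrioritizerBase


class PointsPerSizePrioritizer(TaskPrioritizerBase):
    """Hypothetical prioritizer: points promised per unit of task size."""

    def prioritize(self, task: Task, points: float, *args, **kwargs) -> float:
        # Same task-size heuristic as the removed SimpleBroker.
        task_size = len(task.args[0]) if task.args else 1
        # Negate so that spending more points yields earlier execution,
        # as the old broker-based pool did.
        return -points / task_size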

+ 7 - 7
tests/test_dust_payment.py

@@ -6,18 +6,18 @@ from hivemind.compression import deserialize_tensor_stream, deserialize_torch_te
 from hivemind.proto.runtime_pb2 import ExpertRequest
 from hivemind.utils import MSGPackSerializer, amap_in_executor, iter_as_aiter, split_for_streaming
 
-from src.client.dust_bank import DustBankBase
-from src.client.dusty_block import DustyRemoteBlock
+from src.client.priority_block import DustyRemoteBlock
+from src.client.spending_policy import SpendingPolicyBase
 
 
-class DustBankTest(DustBankBase):
+class SpendingPolicyTest(SpendingPolicyBase):
     def __init__(self):
         self._p = {
             "rpc_single": 1,
             "rpc_stream": 2,
         }
 
-    def get_dust(self, request: ExpertRequest, method_name: str) -> float:
+    def get_points(self, request: ExpertRequest, method_name: str) -> float:
         return self._p.get(method_name, -1)
 
 
@@ -41,7 +41,7 @@ class RemoteBlockTest(DustyRemoteBlock):
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_single():
-    remote = RemoteBlockTest(DustBankTest(), None, None)
+    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
     stub = remote.stub
     input = torch.randn(1, 2)
     request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(input)])
@@ -62,7 +62,7 @@ async def test_single():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_stream():
-    remote = RemoteBlockTest(DustBankTest(), None, None)
+    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
     stub = remote.stub
     input = torch.randn(2**21, 2)
 
@@ -92,7 +92,7 @@ async def test_stream():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_no_wrapper():
-    remote = RemoteBlockTest(DustBankTest(), None, None)
+    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
     stub = remote.stub
 
     test = await stub.rpc_info("Test")

+ 0 - 0
tests/test_dust_pool.py → tests/test_priority_pool.py