simple dirty dust points system

Pavel Samygin 3 years ago
parent
commit
170a57aca6
4 changed files with 124 additions and 28 deletions
  1. src/client/dust_bank.py    +83 -0
  2. src/server/backend.py      +5 -5
  3. src/server/handler.py      +25 -18
  4. src/server/task_broker.py  +11 -5

+ 83 - 0
src/client/dust_bank.py

@@ -0,0 +1,83 @@
+import inspect
+from abc import ABC, abstractmethod
+from functools import wraps
+from typing import AsyncIterator, Callable, Optional
+
+from hivemind.p2p import StubBase
+from hivemind.proto import runtime_pb2
+from hivemind.proto.runtime_pb2 import ExpertRequest
+from hivemind.utils import MSGPackSerializer, amap_in_executor
+
+
+class DustBankBase(ABC):
+    @abstractmethod
+    def get_dust(self, request: ExpertRequest, method_name: str) -> float:
+        pass
+
+
+class DummyDustBank(DustBankBase):
+    def get_dust(self, request: ExpertRequest, method_name: str) -> float:
+        return 0.0
+
+
+def _unary_request_wrapper(rpc_call: Callable, rpc_name: str, bank: DustBankBase):
+    @wraps(rpc_call)
+    async def rpc(stub: StubBase, input: runtime_pb2.ExpertRequest, timeout: Optional[float] = None):
+        meta = MSGPackSerializer.loads(input.metadata) if input.metadata else {}
+        meta.update("__dust", bank.get_dust(input, rpc_name))
+        input.metadata = MSGPackSerializer.dumps(meta)
+        return await rpc_call(stub, input, timeout)
+
+    return rpc
+
+
+def _stream_request_wrapper(rpc_call: Callable, rpc_name: str, bank: DustBankBase):
+    @wraps(rpc_call)
+    async def rpc(stub: StubBase, input: AsyncIterator[runtime_pb2.ExpertRequest], timeout: Optional[float] = None):
+        is_meta_set = False
+
+        def _metadata_setter(chunk: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertRequest:
+            nonlocal is_meta_set
+            if not is_meta_set:
+                meta = MSGPackSerializer.loads(chunk.metadata) if chunk.metadata else {}
+                meta.update("__dust", bank.get_dust(chunk, rpc_name))
+                chunk.metadata = MSGPackSerializer.dumps(meta)
+                is_meta_set = True
+            return chunk
+
+        return await rpc_call(stub, amap_in_executor(_metadata_setter, input), timeout)
+
+    return rpc
+
+
+def _dustify_handler_stub(stub: StubBase, bank: DustBankBase) -> StubBase:
+    for name, method in inspect.getmembers(stub, predicate=inspect.ismethod):
+        if name.startswith("rpc"):
+            spec = inspect.getfullargspec(method)
+            # rpc callers have 3 arguments: stub, input, and timeout
+            if len(spec.args) != 3:
+                continue
+
+            input_type = spec.annotations[spec.args[1]]
+
+            if input_type is AsyncIterator[runtime_pb2.ExpertRequest]:
+                setattr(stub, name, _stream_request_wrapper(method, name, bank))
+            elif input_type is runtime_pb2.ExpertRequest:
+                setattr(stub, name, _unary_request_wrapper(method, name, bank))
+    return stub
+
+
+def payment_wrapper(bank: DustBankBase) -> Callable:
+    def class_wrapper(cls):
+        d = cls.__dict__
+        if "stub" not in d or not isinstance(d["stub"], property):
+            raise TypeError('wrapped module class is expected to have a property "stub"')
+        old_stub = d["stub"]
+
+        def _stub(self):
+            stub = old_stub.__get__(self)
+            return _dustify_handler_stub(stub, bank)
+
+        return type(cls.__name__, cls.__bases__, {k: v if k != "stub" else property(_stub) for k, v in d.items()})
+
+    return class_wrapper
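
For context, a minimal usage sketch of the decorator (not part of this commit; FlatRateBank and RemoteModule are hypothetical). payment_wrapper requires the wrapped class itself to define a `stub` property; every `rpc*` method of the returned stub then sends requests whose metadata carries the `__dust` value reported by the bank:

from hivemind.proto.runtime_pb2 import ExpertRequest

from src.client.dust_bank import DustBankBase, payment_wrapper


class FlatRateBank(DustBankBase):
    """Hypothetical bank that promises a fixed amount of dust per RPC."""

    def get_dust(self, request: ExpertRequest, method_name: str) -> float:
        return 1.0


@payment_wrapper(FlatRateBank())
class RemoteModule:
    def __init__(self, stub):
        self._stub = stub

    @property
    def stub(self):  # each access now returns a dust-aware stub
        return self._stub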

+ 5 - 5
src/server/backend.py

@@ -33,20 +33,20 @@ class PrioritizedTaskPool(TaskPool):
-    def __init__(self, *args, broker: TaskBrokerBase = SimpleBroker(), **kwargs):
+    def __init__(self, *args, broker: DustBrokerBase = SimpleBroker(), **kwargs):
         super().__init__(*args, **kwargs)
         self.broker = broker
-        self.pollen_queue = mp.Queue(maxsize=self.tasks.maxsize)
+        self.dust_queue = mp.Queue(maxsize=self.tasks.maxsize)
         self.priority_queue = PriorityQueue(maxsize=self.tasks.maxsize)
 
-    def submit_task(self, *args: torch.Tensor, pollen: float = 0.0) -> Future:
+    def submit_task(self, *args: torch.Tensor, dust: float = 0.0) -> Future:
         f = super().submit_task(*args)
-        self.pollen_queue.put(pollen)
+        self.dust_queue.put(dust)
         return f
 
     def _priortize_tasks(self):
         """Infinite loop prioritizing incoming tasks"""
         while True:
             task = self.tasks.get(block=True)
-            pollen = self.pollen_queue.get(block=True)
-            self.priority_queue.put(PrioritizedTask(-self.broker(task, pollen), task), block=True)
+            dust = self.dust_queue.get(block=True)
+            self.priority_queue.put(PrioritizedTask(-self.broker(task, dust), task), block=True)
 
     def run(self, *args, **kwargs):
         torch.set_num_threads(1)
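
The negation in `_priortize_tasks` is what orders the queue: `PriorityQueue` always pops the smallest element, so storing `-self.broker(task, dust)` makes the task with the highest broker score come out first. A self-contained illustration with toy scores and a stand-in namedtuple for the repo's PrioritizedTask:

import queue
from collections import namedtuple

PrioritizedTask = namedtuple("PrioritizedTask", ["priority", "task"])

pq = queue.PriorityQueue()
for score, name in [(1.0, "low dust"), (5.0, "high dust"), (2.5, "medium dust")]:
    pq.put(PrioritizedTask(-score, name))  # min-heap + negated score = highest first

while not pq.empty():
    print(pq.get().task)  # -> high dust, medium dust, low dust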

+ 25 - 18
src/server/handler.py

@@ -1,5 +1,5 @@
 import contextlib
-from typing import AsyncIterator, Dict, List, Sequence, Union, Tuple, Iterable
+from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
 
 import torch
 from hivemind import (
@@ -7,19 +7,19 @@ from hivemind import (
     MSGPackSerializer,
     P2PContext,
     TensorDescriptor,
-    deserialize_torch_tensor,
     deserialize_tensor_stream,
+    deserialize_torch_tensor,
     nested_flatten,
     serialize_torch_tensor,
 )
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import anext, amap_in_executor, as_aiter
+from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
-from src.server.backend import MAX_LENGTH, TransformerBackend, PrioritizedTaskPool
+from src.server.backend import MAX_LENGTH, PrioritizedTaskPool, TransformerBackend
 from src.utils.misc import DUMMY, is_dummy
 
 
@@ -55,7 +55,6 @@ class TransformerConnectionHandler(ConnectionHandler):
         inputs = await deserialize_tensor_stream(tensors_stream)
         return expert_uid, inputs, metadata
 
-
     async def rpc_inference(
         self,
         requests: AsyncIterator[runtime_pb2.ExpertRequest],
@@ -80,7 +79,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
                     metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-                    pollen = metadata.get("pollen", 0.0)
+                    dust = metadata.get("__dust", 0.0)
 
                     # Cast inputs to backend dtype
                     hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
@@ -92,7 +91,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                             len(hidden_states) == 1 and hidden_states[0].ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
                         if isinstance(backend.inference_pool, PrioritizedTaskPool):
-                            hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states, pollen)
+                            hidden_states = await backend.inference_pool.submit_task(
+                                cache_metadata, *hidden_states, dust=dust
+                            )
                         else:
                             hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                         assert isinstance(hidden_states, (list, tuple))
@@ -120,9 +121,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        pollen = metadata.get("pollen", 0.0)
+        dust = metadata.get("__dust", 0.0)
 
-        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, pollen=pollen)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, dust=dust)
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize output and respond to client
@@ -141,7 +142,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, pollen=metadata.get("pollen", 0.0))
+        hidden_states = await _rpc_forward(
+            *flat_inputs, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
+        )
         assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize the overall output
@@ -163,9 +166,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
         metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        pollen = metadata.get("pollen", 0.0)
+        dust = metadata.get("__dust", 0.0)
 
-        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, pollen=pollen)
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, dust=dust)
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
@@ -191,7 +194,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, pollen=metadata.get("pollen", 0.0))
+        grads = await _rpc_backward(
+            *flat_tensors, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
+        )
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
@@ -242,7 +247,9 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield handles
 
 
-async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], pollen: float = 0.0) -> torch.Tensor:
+async def _rpc_forward(
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
+) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
 
@@ -269,7 +276,7 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
         if not is_dummy(prompt):
             hidden_states[:, :pre_seq_len] += prompt
         if isinstance(backend.forward_pool, PrioritizedTaskPool):
-            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, pollen)
+            (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, dust=dust)
         else:
             (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
@@ -282,7 +289,7 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
 
 
 async def _rpc_backward(
-    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], pollen: float = 0.0
+    *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     inputs, grad_outputs, *prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
@@ -307,7 +314,7 @@ async def _rpc_backward(
         inter_inputs.append(inputs)
 
         if isinstance(backend.forward_pool, PrioritizedTaskPool):
-            (inputs,) = await backend.forward_pool.submit_task(inputs, pollen / 2.0)
+            (inputs,) = await backend.forward_pool.submit_task(inputs, dust=dust / 2.0)
         else:
             (inputs,) = await backend.forward_pool.submit_task(inputs)
 
@@ -322,7 +329,7 @@ async def _rpc_backward(
     # Run a chain of requested backends
     for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
         if isinstance(backend.backward_pool, PrioritizedTaskPool):
-            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, pollen / 2.0)
+            (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, dust=dust / 2.0)
         else:
             (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
         assert isinstance(grad_outputs, torch.Tensor)
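
Both halves of the `__dust` handshake appear in this commit: the client-side wrappers in dust_bank.py write the value into `request.metadata`, and the handlers above read it back with a 0.0 fallback for clients that send no dust. A condensed sketch of that round-trip:

from hivemind.proto import runtime_pb2
from hivemind.utils import MSGPackSerializer

# Client side (as in _unary_request_wrapper): attach dust to request metadata
request = runtime_pb2.ExpertRequest(uid="expert.0")
meta = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
meta["__dust"] = 2.5
request.metadata = MSGPackSerializer.dumps(meta)

# Server side (as in rpc_forward / rpc_backward): read it back
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
dust = metadata.get("__dust", 0.0)  # 2.5 here; 0.0 if the client sent none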

+ 11 - 5
src/server/task_broker.py

@@ -3,13 +3,19 @@ from abc import ABC, abstractmethod
 from hivemind.moe.server.task_pool import Task
 
 
-class TaskBrokerBase(ABC):
+class DustBrokerBase(ABC):
+    """Abstract class for DustBroker whose reponsibility is to evaluate task profit"""
+
     @abstractmethod
-    def __call__(self, task: Task, pollen: float) -> float:
+    def __call__(self, task: Task, dust: float) -> float:
+        """Evaluates task value by the amout of dust promised"""
         pass
 
 
-class SimpleBroker(TaskBrokerBase):
-    def __call__(self, task: Task, pollen: float) -> float:
+class SimpleBroker(DustBrokerBase):
+    """Simple implementation of DustBroker which counts amount of dust per task size"""
+
+    def __call__(self, task: Task, dust: float) -> float:
+        # TODO: this was taken from the original task pool. Is it right?
         task_size = len(task.args[0]) if task.args else 1
-        return pollen / task_size
+        return dust / task_size
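
A custom pricing policy only needs to implement `__call__`. For instance, a hypothetical broker that caps how much dust a single task may bid before normalizing by task size (CappedBroker and max_dust are illustrative, not part of this commit):

from hivemind.moe.server.task_pool import Task

from src.server.task_broker import DustBrokerBase


class CappedBroker(DustBrokerBase):
    """Hypothetical broker: cap per-task dust, then normalize by task size."""

    def __init__(self, max_dust: float = 100.0):
        self.max_dust = max_dust

    def __call__(self, task: Task, dust: float) -> float:
        task_size = len(task.args[0]) if task.args else 1
        return min(dust, self.max_dust) / task_size

A pool would then be constructed as PrioritizedTaskPool(..., broker=CappedBroker(50.0)), with everything else unchanged.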