Explorar o código

delete DustyBlock, cosmetic changes

Pavel Samygin hai 2 anos
pai
achega
a7395fe27c

+ 0 - 1
src/client/__init__.py

@@ -1,5 +1,4 @@
 from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
-from src.client.priority_block import DustyRemoteBlock
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager

+ 0 - 72
src/client/priority_block.py

@@ -1,72 +0,0 @@
-from __future__ import annotations
-
-import inspect
-from functools import wraps
-from typing import AsyncIterator, Callable, Optional
-
-from hivemind.moe.client import RemoteExpert
-from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.p2p import P2P, StubBase
-from hivemind.proto import runtime_pb2
-from hivemind.utils import MSGPackSerializer, amap_in_executor
-
-from src.client.spending_policy import SpendingPolicyBase
-
-
-# TODO: (greenfatguy) remove later, left for now as example
-class DustyRemoteBlock(RemoteExpert):
-    def __init__(self, bank: SpendingPolicyBase, expert_info: ExpertInfo, p2p: P2P):
-        self._spending_policy = bank
-        super().__init__(expert_info, p2p)
-
-    def _unary_request_wrapper(self, rpc_call: Callable, rpc_name: str):
-        @wraps(rpc_call)
-        async def rpc(input: runtime_pb2.ExpertRequest, timeout: Optional[float] = None):
-            meta = MSGPackSerializer.loads(input.metadata) if input.metadata else {}
-            meta["__dust"] = self._spending_policy.get_points(input, rpc_name)
-            input.metadata = MSGPackSerializer.dumps(meta)
-            return await rpc_call(input, timeout)
-
-        return rpc
-
-    def _stream_request_wrapper(self, rpc_call: Callable, rpc_name: str):
-        @wraps(rpc_call)
-        async def rpc(input: AsyncIterator[runtime_pb2.ExpertRequest], timeout: Optional[float] = None):
-            is_meta_set = False
-
-            def _metadata_setter(chunk: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertRequest:
-                nonlocal is_meta_set
-                if not is_meta_set:
-                    meta = MSGPackSerializer.loads(chunk.metadata) if chunk.metadata else {}
-                    meta["__dust"] = self._spending_policy.get_points(chunk, rpc_name)
-                    chunk.metadata = MSGPackSerializer.dumps(meta)
-                    is_meta_set = True
-                return chunk
-
-            return await rpc_call(amap_in_executor(_metadata_setter, input), timeout)
-
-        return rpc
-
-    def _prioritize_handler_stub_calls(self, stub: StubBase) -> StubBase:
-        for name, method in inspect.getmembers(stub, predicate=inspect.ismethod):
-            if name.startswith("rpc"):
-                spec = inspect.getfullargspec(method)
-                # rpc callers has 3 arguments: stub, input and timeout
-                if len(spec.args) != 3:
-                    continue
-
-                input_type = spec.annotations[spec.args[1]]
-
-                if input_type is AsyncIterator[runtime_pb2.ExpertRequest]:
-                    setattr(stub, name, self._stream_request_wrapper(method, name))
-                elif input_type is runtime_pb2.ExpertRequest:
-                    setattr(stub, name, self._unary_request_wrapper(method, name))
-        return stub
-
-    @property
-    def _stub(self) -> StubBase:
-        return super().stub
-
-    @property
-    def stub(self) -> StubBase:
-        return self._prioritize_handler_stub_calls(self._stub)

+ 6 - 1
src/server/handler.py

@@ -138,7 +138,12 @@ class TransformerConnectionHandler(ConnectionHandler):
                             backend.inference_pool, PrioritizedTaskPool
                         ), "petals support only prioritized pools"
                         priority = self._prioritizer.prioritize(
-                            cache_metadata, hidden_states, hypo_ids, points=point_per_piece / len(requested_backends)
+                            cache_metadata,
+                            hidden_states,
+                            hypo_ids,
+                            points=point_per_piece / len(requested_backends),
+                            backend=backend,
+                            type="inference",
                         )
                         (hidden_states,) = await backend.inference_pool.submit_task(
                             cache_metadata, hidden_states, hypo_ids, priority=priority

+ 3 - 3
src/server/task_prioritizer.py

@@ -5,16 +5,16 @@ from hivemind.moe.server.task_pool import Task
 
 
 class TaskPrioritizerBase(ABC):
-    """Abstract class for DustBroker whose reponsibility is to evaluate task profit"""
+    """Abstract class for TaskPrioritizer whose reponsibility is to evaluate task priority"""
 
     @abstractmethod
     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
-        """Evaluates task value by the amout of points given"""
+        """Evaluates task value by the amout of points given, task input and additional kwargs. Lower priority is better"""
         pass
 
 
 class DummyTaskPrioritizer(TaskPrioritizerBase):
-    """Simple implementation of DustBroker which counts amount of dust per task size"""
+    """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
 
     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
         return 0.0

+ 0 - 99
tests/test_point_system.py

@@ -1,99 +0,0 @@
-from typing import AsyncIterator, Optional
-
-import pytest
-import torch
-from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.proto.runtime_pb2 import ExpertRequest
-from hivemind.utils import MSGPackSerializer, amap_in_executor, iter_as_aiter, split_for_streaming
-
-from src.client.priority_block import DustyRemoteBlock
-from src.client.spending_policy import SpendingPolicyBase
-
-
-class SpendingPolicyTest(SpendingPolicyBase):
-    def __init__(self):
-        self._p = {
-            "rpc_single": 1,
-            "rpc_stream": 2,
-        }
-
-    def get_points(self, request: ExpertRequest, method_name: str) -> float:
-        return self._p.get(method_name, -1)
-
-
-class HandlerStubTest:
-    async def rpc_single(self, input: ExpertRequest, timeout: Optional[float] = None):
-        return input
-
-    async def rpc_stream(self, input: AsyncIterator[ExpertRequest], timeout: Optional[float] = None):
-        return input
-
-    async def rpc_info(self, input: str, timeout: Optional[float] = None):
-        return input
-
-
-class RemoteBlockTest(DustyRemoteBlock):
-    @property
-    def _stub(self):
-        return HandlerStubTest()
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_single():
-    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
-    stub = remote.stub
-    input = torch.randn(1, 2)
-    request = ExpertRequest(uid="expert1", tensors=[serialize_torch_tensor(input)])
-
-    print(stub)
-    out: ExpertRequest = await stub.rpc_single(request)
-
-    assert out.metadata != b""
-    assert len(out.tensors) == 1
-    assert torch.allclose(input, deserialize_torch_tensor(out.tensors[0]))
-
-    meta = MSGPackSerializer.loads(out.metadata)
-    assert isinstance(meta, dict)
-    assert "__dust" in meta
-    assert meta["__dust"] == 1
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_stream():
-    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
-    stub = remote.stub
-    input = torch.randn(2**21, 2)
-
-    split = (p for t in [serialize_torch_tensor(input)] for p in split_for_streaming(t, chunk_size_bytes=2**16))
-    output_generator = await stub.rpc_stream(
-        amap_in_executor(
-            lambda tensor_part: ExpertRequest(uid="expert2", tensors=[tensor_part]),
-            iter_as_aiter(split),
-        ),
-    )
-    outputs_list = [part async for part in output_generator]
-    assert len(outputs_list) == 2**5 * 8
-    assert outputs_list[0].metadata != b""
-    for i in range(1, len(outputs_list)):
-        assert outputs_list[i].metadata == b""
-
-    meta = MSGPackSerializer.loads(outputs_list[0].metadata)
-    assert isinstance(meta, dict)
-    assert "__dust" in meta
-    assert meta["__dust"] == 2
-
-    results = await deserialize_tensor_stream(amap_in_executor(lambda r: r.tensors, iter_as_aiter(outputs_list)))
-    assert len(results) == 1
-    assert torch.allclose(results[0], input)
-
-
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_no_wrapper():
-    remote = RemoteBlockTest(SpendingPolicyTest(), None, None)
-    stub = remote.stub
-
-    test = await stub.rpc_info("Test")
-    assert test == "Test"