reduce diff

Your Name, 1 year ago · commit c665c42cf2

src/petals/client/inference_session.py  +2 -2

@@ -34,7 +34,7 @@ class _ServerInferenceSession:
         span: RemoteSpanInfo,
         span_uids: Sequence[ModuleUID],
         inputs_queue: asyncio.Queue,
-        outputs_stream: AsyncIterator,
+        outputs_aiter: AsyncIterator,
         *block_kwargs,
         max_length: int,
     ):
@@ -42,7 +42,7 @@ class _ServerInferenceSession:
         self.span, self.span_uids = span, span_uids
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
-        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_stream
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
         self.session_id = str(uuid.uuid4())
         self.max_length = max_length
         self.stepped = False
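
The rename only touches the constructor parameter; the underlying pattern, a queue of outgoing requests paired with an async iterator of incoming responses, is unchanged. A minimal, self-contained sketch of that pattern, using illustrative names rather than the Petals API:

import asyncio
from typing import AsyncIterator


async def echo_server(inputs_queue: asyncio.Queue) -> AsyncIterator[str]:
    # Yield one response per request pulled from the queue (stand-in for the RPC stream).
    while True:
        request = await inputs_queue.get()
        if request is None:  # sentinel closes the stream
            return
        yield f"response to {request}"


async def main() -> None:
    inputs_queue: asyncio.Queue = asyncio.Queue()
    outputs_aiter = echo_server(inputs_queue)  # plays the role of the renamed argument

    await inputs_queue.put("step 1")
    print(await outputs_aiter.__anext__())  # -> "response to step 1"
    await outputs_aiter.aclose()


asyncio.run(main())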

src/petals/client/sequential_autograd.py  +0 -1

@@ -233,7 +233,6 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
             prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
             prompt_batches = tuple(batch.requires_grad_(prompts.requires_grad) for batch in prompt_batches)
 
-        sequence_manager.rpc_info  # lazy init #TODO no longer needed
         outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
         assert len(outputs) == len(input_batches)
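
The deleted line existed only for its side effect: reading sequence_manager.rpc_info forced a lazily initialized attribute to populate before the forward pass, and the TODO marks that as no longer necessary. A generic sketch of that lazy-init-by-access pattern (not the actual Petals code), for readers wondering why a bare attribute access was there at all:

import functools


class Service:
    @functools.cached_property
    def rpc_info(self):
        print("expensive handshake, runs only on first access")
        return {"version": 1}


service = Service()
service.rpc_info  # accessed purely to trigger initialization; the value is discarded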
 

tests/test_priority_pool.py  +5 -4

@@ -4,8 +4,8 @@ import time
 
 import pytest
 import torch
-from hivemind.moe.server.runtime import Runtime
 
+from petals.server.server import RuntimeWithDeduplicatedPools
 from petals.server.task_pool import PrioritizedTaskPool
 
 
@@ -35,8 +35,7 @@ def test_priority_pools():
     runtime_ready = mp.Event()
     results_valid = mp.Event()
 
-    def dummy_pool_func(args, kwargs):
-        (x,) = args  # TODO modify the PriorityPool code such that dummy_pool_func can accept x directly
+    def dummy_pool_func(x):
         time.sleep(0.1)
         y = x**2
         outputs_queue.put((x, y))
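
The resolved TODO implies the pool previously handed its function a raw (args, kwargs) pair and now unpacks it before the call, so the test function can accept x directly. A hypothetical before/after of that dispatch convention (not the actual PrioritizedTaskPool internals):

def dispatch_old(fn, args, kwargs):
    return fn(args, kwargs)      # fn had to unpack: (x,) = args


def dispatch_new(fn, args, kwargs):
    return fn(*args, **kwargs)   # fn receives x directly


assert dispatch_new(lambda x: x**2, (3,), {}) == 9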
@@ -58,7 +57,9 @@ def test_priority_pools():
     proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid))
     proc.start()
 
-    runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
+    runtime = RuntimeWithDeduplicatedPools(
+        {str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0
+    )
     runtime.ready = runtime_ready
     runtime.start()
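
Swapping hivemind's Runtime for RuntimeWithDeduplicatedPools presumably matters because several backends in Petals can share the same task pool object, and a shared pool should only be registered once. A rough, self-contained illustration of deduplication by object identity; TinyBackend and TinyRuntime are invented for this sketch and are not the real classes:

class TinyBackend:
    def __init__(self, pools):
        self.pools = pools


class TinyRuntime:
    def __init__(self, backends):
        unique, seen = [], set()
        for backend in backends.values():
            for pool in backend.pools:
                if id(pool) not in seen:  # keep each distinct pool object once
                    seen.add(id(pool))
                    unique.append(pool)
        self.pools = unique


shared_pool = object()
backends = {"0": TinyBackend([shared_pool]), "1": TinyBackend([shared_pool])}
assert len(TinyRuntime(backends).pools) == 1  # the shared pool is registered only once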