artek0chumak committed 3 years ago · parent · revision 389d220d2d

+ 2 - 2
src/client/remote_block.py

@@ -78,12 +78,12 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self, new_hidden_states: torch.Tensor, batch_ids: torch.Tensor):
+    def step(self, new_hidden_states: torch.Tensor):
         """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states, batch_ids)
+        inputs = (new_hidden_states,)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(

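For orientation, a sketch of the client call after this change: step() now takes only the hidden-states chunk, and batch_ids disappears from the request payload (the serialized inputs tuple shrinks to a single tensor). The stub below is a local stand-in for RemoteTransformerBlockInferenceSession, not the real remote class, and the tensor sizes are placeholders:

    import torch

    class _StubSession:
        """Local stand-in for RemoteTransformerBlockInferenceSession (illustration only)."""
        closed = False

        def step(self, new_hidden_states: torch.Tensor) -> torch.Tensor:
            if self.closed:
                raise Exception("Session is closed, cannot perform step")
            return new_hidden_states  # a real session round-trips through the server

    hidden = torch.randn(1, 8, 4096)       # (batch, seq_chunk, hidden_dim), sizes assumed
    outputs = _StubSession().step(hidden)  # single-tensor call matches the new signature
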
+ 3 - 4
src/client/remote_sequential.py

@@ -140,12 +140,11 @@ class RemoteSequentialInferenceSession:
 
         return self
 
-    def step(self, inputs: torch.Tensor, batch_ids: torch.Tensor):
+    def step(self, inputs: torch.Tensor):
         assert not self.closed
         for session in self.active_sessions:
-            outputs = session.step(inputs, batch_ids)
-            assert outputs.shape[1:] == inputs.shape[1:], f"expected {inputs.shape[1:]}, got {outputs.shape[1:]}"
-            assert outputs.shape[0] == batch_ids.shape[0], f"expected {batch_ids.shape[0]}, got {outputs.shape[0]}"
+            outputs = session.step(inputs)
+            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs
 
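The tightened assertion makes the pipeline invariant explicit: without batch_ids, a stage can no longer return a re-gathered subset of the batch, so every session must hand back a tensor shaped exactly like its input, which then feeds the next stage. A toy sketch of that loop (stand-in stages, not the real client classes):

    import torch

    class _ToyStage:
        """Stand-in for one per-block session in self.active_sessions."""
        def step(self, inputs: torch.Tensor) -> torch.Tensor:
            return inputs + 1.0  # placeholder shape-preserving transform

    inputs = torch.zeros(2, 4, 16)              # (batch, seq_chunk, hidden), sizes assumed
    for session in [_ToyStage(), _ToyStage()]:
        outputs = session.step(inputs)
        assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
        inputs = outputs                        # chain each stage's output into the next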

+ 1 - 1
src/server/backend.py

@@ -1,5 +1,5 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Optional, Sequence, Tuple
+from typing import Sequence, Tuple
 
 import torch
 from hivemind.moe.server.module_backend import ModuleBackend

+ 0 - 8
src/server/cache.py

@@ -122,14 +122,6 @@ class MemoryCache:
         assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
         yield self._allocated_tensors[handle]
 
-    @staticmethod
-    def update_cache_via_batch_ids(cache: torch.Tensor, batch_ids: torch.Tensor) -> None:
-        new_cache_shape = cache.shape
-        new_cache_shape[1] = batch_ids.size(0)
-        new_cache = torch.zeros(new_cache_shape)
-        new_cache.scatter_(1, batch_ids, cache)
-        cache.copy_(new_cache)
-
 
 class AllocationFailed(Exception):
     pass
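
A side note on the deleted helper: it could not have run as written, since cache.shape returns a torch.Size, which does not support item assignment, and the 1-D batch_ids does not match the rank scatter_ requires of its index. If batch reordering of the cache is ever needed again, a working variant might look like the sketch below, assuming the intent was to move each row along dim 1 to the slot named in batch_ids (an inference from the removed code, not a documented contract):

    import torch

    def update_cache_via_batch_ids(cache: torch.Tensor, batch_ids: torch.Tensor) -> None:
        """Move row j of dim 1 to slot batch_ids[j] in place, zero-filling the rest."""
        assert batch_ids.dtype == torch.int64, "scatter_ requires a long index tensor"
        assert batch_ids.size(0) == cache.size(1), "one destination slot per cached row"
        new_cache = torch.zeros_like(cache)
        # scatter_ wants an index of the same rank as the source, so broadcast
        # the 1-D batch_ids across every other dimension of the cache
        index = batch_ids.view(1, -1, *([1] * (cache.dim() - 2))).expand_as(cache)
        new_cache.scatter_(1, index, cache)
        cache.copy_(new_cache)  # shapes match, so the in-place copy is valid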