Artem Chumachenko 3 years ago
parent
commit 79a9ff2b2e
3 changed files with 15 additions and 8 deletions
  1. + 2 - 2   src/client/inference_session.py
  2. + 10 - 4  src/server/backend.py
  3. + 3 - 2   src/server/handler.py

+ 2 - 2
src/client/inference_session.py

@@ -64,14 +64,14 @@ class RemoteTransformerBlockInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
+        inputs = (new_hidden_states, torch.arange(len(new_hidden_states)))
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor, proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
                 )
             )

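Note (not part of the commit): the second element of `inputs` is a per-sequence tensor of hypothesis ids. Sending `torch.arange(len(new_hidden_states))` keeps the identity order; a beam-search client could instead send a permutation so the server reorders its attention cache between steps. A minimal sketch of the client-side pair, with illustrative shapes:

    import torch

    new_hidden_states = torch.randn(2, 1, 14336)     # [batch, seq_len, hid_size]; hid_size is illustrative
    hypo_ids = torch.arange(len(new_hidden_states))  # identity order, dtype=torch.int64
    inputs = (new_hidden_states, hypo_ids)           # two tensors, matching the server's inference_schema
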
+ 10 - 4
src/server/backend.py

@@ -1,9 +1,9 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-from typing import Sequence, Tuple
+from typing import Sequence, Tuple, Dict, Any
 
 import torch
-from hivemind import use_hivemind_log_handler
+from hivemind import use_hivemind_log_handler, BatchTensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils import InvalidStateError, get_logger
@@ -56,19 +56,21 @@ class TransformerBackend(ModuleBackend):
         self.inference_pool = InferenceTaskPool(
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
+        self.inference_schema = (self.args_schema, self.kwargs_schema, BatchTensorDescriptor((), dtype=torch.int64))
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
+            hidden_states, hypo_ids = inputs  # todo: in future, it would be best to support attention mask here
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
+                arange = torch.arange(prefix_length)
+                layer_past = past_k, past_v = cache[0, hypo_ids, arange], cache[1, hypo_ids, arange]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(
                     hidden_states, layer_past=layer_past, use_cache=True
@@ -88,3 +90,7 @@ class TransformerBackend(ModuleBackend):
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)
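Note (not part of the commit): indexing the batch dimension of cached keys/values with `hypo_ids` lets the server duplicate or reorder past entries when hypotheses are re-ranked between generation steps. A minimal sketch under an assumed `[batch, num_heads, prefix, head_dim]` cache layout (names and shapes are illustrative, not the commit's actual cache format):

    import torch

    batch, num_heads, prefix, head_dim = 3, 2, 4, 8
    past_k = torch.randn(batch, num_heads, prefix, head_dim)
    hypo_ids = torch.tensor([0, 0, 2])   # beams 0 and 1 both continue from beam 0
    reordered_k = past_k[hypo_ids]       # same shape, rows selected per surviving hypothesis
    assert reordered_k.shape == past_k.shape
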

+ 3 - 2
src/server/handler.py

@@ -46,7 +46,8 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                    assert len(request.tensors) == 2, "Must specify hidden_states and input_ids" # TODO replace with schema
+                    hidden_states, hypo_ids = map(deserialize_torch_tensor, request.tensors)
 
                     # run request tensors through all requested modules, update caches
                     for backend, cache_handle in zip(requested_backends, cache_handles):
@@ -55,7 +56,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                             len(hidden_states) == 1 and hidden_states[0].ndim == 3
                         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
 
-                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
+                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, hidden_states, hypo_ids)
                         assert isinstance(hidden_states, (list, tuple))
                         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3