
introduce hypo_ids

Artem Chumachenko, 3 years ago
parent commit 1afd59a071
3 changed files with 17 additions and 10 deletions
  1. src/client/inference_session.py (+1 -1)
  2. src/server/backend.py (+10 -4)
  3. src/server/handler.py (+6 -5)

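For context: hypo_ids index the batch dimension of the server-side attention cache, so that a client running beam search can reorder or duplicate cached hypotheses between steps instead of re-sending the whole prefix. A minimal sketch of the idea, assuming a cache laid out as [keys/values, batch, prefix, hid] as in backend.py below (all names here are illustrative, not part of the commit):

    import torch

    # toy attention cache: [2 (keys/values), batch=4, prefix_length=10, hid_size=8]
    cache = torch.randn(2, 4, 10, 8)

    # beam search kept hypotheses 2 and 0, and forked hypothesis 2
    hypo_ids = torch.tensor([2, 2, 0])

    # gather the cached prefixes of the surviving hypotheses
    past_k = cache[0, hypo_ids, :10]  # shape: [3, 10, 8]
    past_v = cache[1, hypo_ids, :10]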
+ 1 - 1
src/client/inference_session.py

@@ -74,7 +74,7 @@ class RemoteTransformerBlockInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
+        inputs = (new_hidden_states, torch.arange(len(new_hidden_states)))
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(

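Note that the client so far always sends the identity permutation, so every cache entry stays in place; a beam-search client would instead send the indices of the hypotheses that survived the previous step. A hedged sketch (the second variant is hypothetical, not implemented in this commit):

    import torch

    new_hidden_states = torch.randn(3, 1, 8)  # [batch, n_new_tokens, hid_size]

    # what this commit sends: keep all hypotheses in place
    hypo_ids = torch.arange(len(new_hidden_states))  # tensor([0, 1, 2])

    # what a beam-search client might send: hypothesis 1 forked, hypothesis 2 survived
    # hypo_ids = torch.tensor([1, 1, 2])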
+ 10 - 4
src/server/backend.py

@@ -1,9 +1,9 @@
 """Code for serving bloom blocks via hivemind-server"""
 from queue import Empty
-from typing import Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import torch
-from hivemind import use_hivemind_log_handler
+from hivemind import use_hivemind_log_handler, BatchTensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils import InvalidStateError, get_logger
@@ -55,19 +55,21 @@ class TransformerBackend(ModuleBackend):
             self.inference_step, max_batch_size=self.forward_pool.max_batch_size, name=f"{self.name}_inference"
         )
         self.dtype = backend_dtype if backend_dtype else self.module.input_layernorm.weight.dtype
+        self.inference_schema = (self.args_schema, self.kwargs_schema, BatchTensorDescriptor((), dtype=torch.int64))
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
-            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
+            hidden_states, hypo_ids = inputs  # todo: in future, it would be best to support attention mask here
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
+                # gather each hypothesis' cached prefix, reordered by hypo_ids
+                layer_past = past_k, past_v = cache[0, hypo_ids, :prefix_length], cache[1, hypo_ids, :prefix_length]
                 print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
                 hidden_states, (new_k, new_v) = self.module.forward(
                     hidden_states, layer_past=layer_past, use_cache=True
@@ -85,3 +87,7 @@
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)

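The new inference_schema lets clients discover, via get_info(), that inference_step takes the usual forward inputs plus one extra int64 value per batch element: BatchTensorDescriptor((), dtype=torch.int64) describes a per-sample scalar, with the batch dimension prepended automatically. A small sketch of checking a tensor against such a descriptor (client-side validation is an assumption here, not part of the commit):

    import torch
    from hivemind import BatchTensorDescriptor

    # per-sample shape () means one scalar per batch element
    hypo_ids_descr = BatchTensorDescriptor((), dtype=torch.int64)

    # sanity-check hypo_ids before submitting an inference step
    hypo_ids = torch.arange(3)
    assert hypo_ids.dtype == hypo_ids_descr.dtype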
+ 6 - 5
src/server/handler.py

@@ -64,8 +64,9 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-                    length_increment = hidden_states[0].shape[1]  # how many tokens are added this step (in each seq)
+                    assert len(request.tensors) == 2, "expected two tensors: hidden_states and hypo_ids"  # TODO: replace with an input schema
+                    hidden_states, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
 
                     if prefix_length + length_increment > max_length:
                         raise ValueError(
@@ -83,7 +84,7 @@
                         assert (
-                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
-                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                            hidden_states.ndim == 3
+                        ), f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
 
-                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
+                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, hidden_states, hypo_ids)
                         assert isinstance(hidden_states, (list, tuple))
                         assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
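Taken together, each inference step now carries exactly two tensors per request, and the server tracks how far the cached prefix has grown. A condensed, RPC-free sketch of the per-step loop above (steps and max_length stand in for the request stream and the session limit; the pool call is elided):

    import torch

    def run_inference_session(steps, max_length: int):
        """Toy version of the handler loop: enforce max_length across steps."""
        prefix_length = 0
        for hidden_states, hypo_ids in steps:  # [batch, n_new_tokens, hid], [batch]
            length_increment = hidden_states.shape[1]
            if prefix_length + length_increment > max_length:
                raise ValueError(f"exceeded max_length={max_length}")
            # the real handler submits (cache_metadata, hidden_states, hypo_ids)
            # to backend.inference_pool here
            prefix_length += length_increment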