justheuristic, 3 years ago
Parent
Commit 81ea014bf5
3 changed files with 12 additions and 4 deletions
  1. src/client/inference_session.py (+2 -2)
  2. src/server/backend.py (+7 -1)
  3. src/server/handler.py (+3 -1)

+ 2 - 2
src/client/inference_session.py

@@ -64,14 +64,14 @@ class RemoteTransformerBlockInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
+        inputs = (new_hidden_states, torch.arange(len(new_hidden_states)))
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
                         serialize_torch_tensor(tensor, proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
                     ],
                 )
             )
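
In plain terms: the client's inference step now ships a second tensor, hypo_ids, alongside the hidden states, and serializes both against the server-advertised inference_schema instead of the forward_schema. A minimal sketch of the resulting request construction, assuming hivemind's serialize_torch_tensor / nested_flatten helpers and the runtime_pb2.ExpertRequest message (exact import paths vary across hivemind versions):

# Sketch (not part of the commit): how the serialized step request is built.
import torch
from hivemind.proto import runtime_pb2
from hivemind.compression import serialize_torch_tensor
from hivemind.utils import nested_flatten

def build_step_request(uid, new_hidden_states, rpc_info):
    # One hypothesis id per sequence in the batch; identity order by default,
    # mirroring torch.arange(len(new_hidden_states)) in the diff above.
    hypo_ids = torch.arange(len(new_hidden_states))
    inputs = (new_hidden_states, hypo_ids)
    return runtime_pb2.ExpertRequest(
        uid=uid,
        tensors=[
            serialize_torch_tensor(tensor, proto.compression)
            for tensor, proto in zip(inputs, nested_flatten(rpc_info["inference_schema"]))
        ],
    )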

+ 7 - 1
src/server/backend.py

@@ -1,7 +1,8 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Sequence, Tuple
+from typing import Sequence, Tuple, Dict, Any
 
 import torch
+from hivemind import BatchTensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool
 
@@ -24,6 +25,7 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
 
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+        self.inference_schema = (self.args_schema, self.kwargs_schema, BatchTensorDescriptor((), dtype=torch.int64))
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
@@ -56,3 +58,7 @@ class TransformerBackend(ModuleBackend):
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)
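
On the server side, the backend now advertises a three-part inference schema: the usual positional and keyword argument schemas plus a BatchTensorDescriptor((), dtype=torch.int64), i.e. one scalar int64 hypo_id per batch element. Clients obtain it through get_info(). A hypothetical usage sketch (the backend argument and its exact descriptor attributes are assumptions):

# Sketch (not part of the commit): reading the advertised schema client-side.
import torch

def check_inference_schema(backend):
    # "backend" is assumed to be a TransformerBackend instance.
    info = backend.get_info()
    args_schema, kwargs_schema, hypo_ids_descriptor = info["inference_schema"]
    assert hypo_ids_descriptor.dtype == torch.int64  # one scalar id per sequence
    return args_schema, kwargs_schema, hypo_ids_descriptor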

+ 3 - 1
src/server/handler.py

@@ -46,7 +46,9 @@ class TransformerConnectionHandler(ConnectionHandler):
             async with self._allocate_caches(requested_backends, batch_size) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 while request.tensors:  # iterate while user is willing to supply tensors
-                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+                    assert len(request.tensors) == 2, "Must specify hidden_states and input_ids" # TODO replace with schema
+                    hidden_states, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+                    print('OLOLO OLOLO I GOT HYPO IDS: ', hypo_ids)
 
                     # run request tensors through all requested modules, update caches
                     for backend, cache_handle in zip(requested_backends, cache_handles):
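
The handler now requires exactly two tensors per step (hidden states and hypo_ids) and deserializes both; the print statement is temporary debug scaffolding, and the TODO notes that the hard-coded assert should eventually give way to schema-based validation. This diff only receives hypo_ids, but their intended role is to let the server reorder per-sequence attention caches, e.g. when a beam search reshuffles hypotheses between steps. A hypothetical sketch of that reordering (the cache layout is an assumption):

# Sketch (not part of the commit): what hypo_ids make possible on the server.
import torch

def reorder_cache(cache: torch.Tensor, hypo_ids: torch.LongTensor) -> torch.Tensor:
    # Assumed layout: cache[i] holds the attention past for sequence i.
    # hypo_ids[i] names the previous hypothesis that sequence i continues,
    # so selecting rows realigns every sequence with its own history.
    return cache.index_select(0, hypo_ids)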