@@ -1,7 +1,8 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Sequence, Tuple
+from typing import Sequence, Tuple, Dict, Any

 import torch
+from hivemind import BatchTensorDescriptor
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.task_pool import TaskPool

@@ -24,6 +25,7 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"

         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
+        self.inference_schema = (self.args_schema, self.kwargs_schema, BatchTensorDescriptor((), dtype=torch.int64))

     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         with torch.inference_mode():
@@ -56,3 +58,7 @@ class TransformerBackend(ModuleBackend):

     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
+
+    def get_info(self) -> Dict[str, Any]:
+        """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
+        return dict(super().get_info(), inference_schema=self.inference_schema)
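
Note: with this patch, `get_info()` advertises `inference_schema` alongside the base `ModuleBackend` info, so a remote caller can learn the expected argument layout of `inference_step` (the `cache_metadata` tensor followed by the regular inputs) before sending tensors. Below is a minimal client-side sketch of consuming that schema; `backend_info` (standing in for the dict returned by `get_info()`) and the contents of `cache_metadata` are illustrative assumptions, not part of this patch:

    import torch

    def build_inference_inputs(backend_info: dict, hidden_states: torch.Tensor):
        # Unpack the schema advertised by TransformerBackend.get_info()
        args_schema, kwargs_schema, metadata_descr = backend_info["inference_schema"]
        # cache_metadata must match BatchTensorDescriptor((), dtype=torch.int64):
        # a 1-D int64 tensor (one value per batch element); what the values mean
        # (e.g. a memory-cache handle) is an assumption for illustration
        cache_metadata = torch.zeros(1, dtype=metadata_descr.dtype)
        assert cache_metadata.dtype == torch.int64
        # Ordered to match inference_step(cache_metadata, *inputs)
        return (cache_metadata, hidden_states)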