handler.py

from typing import AsyncIterator, Dict

import torch
from hivemind import P2PContext, DHT, deserialize_torch_tensor, TensorDescriptor
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import anext

from src.server.backend import TransformerBlockBackend
class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def rpc_forward_incremental(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> runtime_pb2.ExpertResponse:
        # TODO (incremental inference protocol, not implemented yet -- see the sketch after this file):
        # - encode expert_uid as @model_name[starting_layer:finishing_layer]
        # - while not closed: read input embeddings, check input shapes, run inference, return batch of outputs, repeat
        # - receive and maintain a handle for attention cache here
        # note: you may use self.experts[uid].memory_cache!
        request = await anext(requests)
        expert = self.experts[request.uid]
        assert isinstance(expert, TransformerBlockBackend)
        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

        # placeholder cache allocation; a full implementation would size this to the attention cache
        async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
            outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)

        return runtime_pb2.ExpertResponse(tensors=outputs)
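The TODO notes above describe a streaming session that the handler does not implement yet: the client opens a stream to a range of layers, sends embeddings step by step, and the server keeps an attention-cache handle alive between steps. The following is a minimal sketch of that loop under stated assumptions, not the project's implementation: the async-generator form, the parse_expert_uid helper, the StreamingConnectionHandler subclass, the "empty request ends the session" convention, and the placeholder cache descriptor are all assumptions introduced here for illustration; only the hivemind calls already used in handler.py are reused.

import re
from typing import AsyncIterator, Tuple

import torch
from hivemind import P2PContext, TensorDescriptor, deserialize_torch_tensor
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import anext

from src.server.backend import TransformerBlockBackend
from src.server.handler import TransformerConnectionHandler  # hypothetical import path


def parse_expert_uid(uid: str) -> Tuple[str, int, int]:
    """Hypothetical helper: '@model_name[3:7]' -> ('model_name', 3, 7)."""
    match = re.fullmatch(r"@([\w.-]+)\[(\d+):(\d+)\]", uid)
    if match is None:
        raise ValueError(f"unexpected expert uid format: {uid}")
    name, start, end = match.groups()
    return name, int(start), int(end)


class StreamingConnectionHandler(TransformerConnectionHandler):
    async def rpc_forward_incremental(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        # the first request tells us which block range the client wants to run
        request = await anext(requests)
        expert = self.experts[request.uid]
        assert isinstance(expert, TransformerBlockBackend)
        model_name, start_layer, end_layer = parse_expert_uid(request.uid)

        # hold one attention-cache handle for the whole session;
        # the descriptor below is a placeholder, not the real cache shape
        async with expert.memory_cache.allocate_cache(TensorDescriptor.from_tensor(torch.randn(3))):
            while request.tensors:  # assumption: an empty request (or a closed stream) ends the session
                inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
                # input-shape checks against the expert's schema would go here
                outputs = await self._process_inputs(inputs, expert.forward_pool, expert.outputs_schema)
                yield runtime_pb2.ExpertResponse(tensors=outputs)
                try:
                    request = await anext(requests)
                except StopAsyncIteration:
                    break  # client closed the stream

Keeping the allocate_cache context open around the whole while-loop is what lets attention state survive between steps, which is the point of the "receive and maintain a handle for attention cache here" note; how that handle is actually sized and passed to the block is left open, as in the original file.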