@@ -20,17 +20,26 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
-
-        request = await anext(requests)
-        backend = self.module_backends[request.uid]
-        assert isinstance(backend, TransformerBackend)
-
-        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-
-        hidden_size = backend.module.hidden_size
-        cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
-        async with backend.memory_cache.allocate_cache(cache_descriptor) as handle:
-            inputs.insert(0, torch.tensor([handle], dtype=torch.int64))
-            outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
-
-        yield runtime_pb2.ExpertResponse(tensors=outputs)
+        """Compute a single step of inference using attention cache; update attention cache accordingly."""
+        try:
+            request = await anext(requests)
+            backend = self.module_backends[request.uid]
+            assert isinstance(backend, TransformerBackend)
+
+            # prepare attention cache
+            hidden_size = backend.module.hidden_size
+            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, current_sequence_length]
+            cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
+            current_sequence_length = 0
+
+            async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
+                inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
+                print("INPUTS:", inputs)
+                assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
+                cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
+                outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
+                yield runtime_pb2.ExpertResponse(tensors=outputs)
+
+                current_sequence_length += inputs[1].shape[1]
+        finally:
+            print("CLOSED RPC_INFERENCE")
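Two notes on the patch, with illustrative sketches below. First, the `cache_metadata` tensor is the load-bearing trick: the inference pool presumably only moves tensors between processes, so the cache handle and the current sequence offset travel in-band as an extra int64 tensor prepended to the hidden states. A minimal sketch of how a worker-side step could consume that layout; `inference_step`, `_BUFFERS`, and the concrete `MAX_LENGTH` value are assumptions for illustration, not this repo's actual API:

```python
from typing import Dict

import torch

MAX_LENGTH = 2048                            # assumed cap on the cached sequence length
_BUFFERS: Dict[int, torch.Tensor] = {}       # stand-in for the server-side memory cache


def inference_step(cache_metadata: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
    """Hypothetical worker-side step consuming the [cache_handle, prefix_length] metadata."""
    cache_handle, prefix_length = (int(x) for x in cache_metadata[0])
    cache = _BUFFERS[cache_handle]           # the (1, MAX_LENGTH, hidden_size) buffer for this handle
    new_length = prefix_length + hidden_states.shape[1]
    assert new_length <= MAX_LENGTH, "attention cache is full"
    cache[:, prefix_length:new_length] = hidden_states  # extend the cached prefix in place
    # a real transformer block would attend over cache[:, :new_length] here;
    # returning the freshly written slice keeps the sketch self-contained
    return cache[:, prefix_length:new_length].clone()


# toy demo: register a buffer under handle 0, run one 5-token step at offset 0
_BUFFERS[0] = torch.zeros(1, MAX_LENGTH, 1024)
out = inference_step(torch.tensor([[0, 0]], dtype=torch.int64), torch.randn(1, 5, 1024))
assert out.shape == (1, 5, 1024)
```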
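Second, `allocate_cache` is entered as an async context manager so the attention buffer lives exactly as long as the RPC: it is reserved before the first step and released when the stream closes, which together with the `try/finally` also covers a client disconnecting mid-generation. A toy allocator with the same usage shape, assuming a plain counter for handles; the real memory cache is certainly more involved (shared storage, size accounting):

```python
import contextlib
from typing import AsyncIterator, Dict, Tuple

import torch


class ToyMemoryCache:
    """Illustrative stand-in: mirrors only the handle-yielding context-manager interface."""

    def __init__(self) -> None:
        self._buffers: Dict[int, torch.Tensor] = {}
        self._next_handle = 0

    @contextlib.asynccontextmanager
    async def allocate_cache(self, size: Tuple[int, ...], dtype: torch.dtype) -> AsyncIterator[int]:
        handle, self._next_handle = self._next_handle, self._next_handle + 1
        self._buffers[handle] = torch.zeros(size, dtype=dtype)  # reserve the buffer up front
        try:
            yield handle  # the caller embeds this handle in cache_metadata
        finally:
            del self._buffers[handle]  # freed even if the RPC dies mid-stream
```

Usage mirrors the diff, modulo the descriptor: `async with cache.allocate_cache((1, MAX_LENGTH, hidden_size), torch.float32) as handle: ...`; the `TensorDescriptor` form in the patch carries the same size/dtype information in one object.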