ソースを参照

basic cache checks (via debugprint)

justheuristic 3 年 前
コミット
a44cb84f06
2 ファイル変更28 行追加19 行削除
  1. 5 5
      src/server/backend.py
  2. 23 14
      src/server/handler.py

+ 5 - 5
src/server/backend.py

@@ -26,14 +26,14 @@ class TransformerBackend(ModuleBackend):
 
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
 
-    def inference_step(self, attention_cache_handle: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-
-        attention_cache_handle = int(attention_cache_handle.item())
-        print('HANDLE:', attention_cache_handle)
+    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        attention_cache_handle = int(cache_metadata[0, 0].item())
+        current_sequence_length = int(cache_metadata[0, 1].item())
         with self.memory_cache.use_cache(attention_cache_handle) as cache:
+            print('METADATA:', cache_metadata, "CACHE ENTRIES:", len(self.memory_cache._allocated_tensors))
             print(inputs[0].shape, cache.shape)
             cache[...] += 1
-            return (inputs[0] + cache,)
+            return (inputs[0] + cache.flatten()[0],)
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool

+ 23 - 14
src/server/handler.py

@@ -20,17 +20,26 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
-
-        request = await anext(requests)
-        backend = self.module_backends[request.uid]
-        assert isinstance(backend, TransformerBackend)
-
-        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-
-        hidden_size = backend.module.hidden_size
-        cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
-        async with backend.memory_cache.allocate_cache(cache_descriptor) as handle:
-            inputs.insert(0, torch.tensor([handle], dtype=torch.int64))
-            outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
-
-        yield runtime_pb2.ExpertResponse(tensors=outputs)
+        """Compute a single step of inference using attention cache; update attention cache accordingly."""
+        try:
+            request = await anext(requests)
+            backend = self.module_backends[request.uid]
+            assert isinstance(backend, TransformerBackend)
+
+            # prepare attention cache
+            hidden_size = backend.module.hidden_size
+            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, current_sequence_length]
+            cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
+            current_sequence_length = 0
+
+            async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
+                inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
+                print("INPUTS:", inputs)
+                assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
+                cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
+                outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
+                yield runtime_pb2.ExpertResponse(tensors=outputs)
+
+                current_sequence_length += inputs[1].shape[1]
+        finally:
+            print("CLOSED RPC_INFERENCE")