justheuristic 3 years ago
Parent
Commit
33358bc52b
2 changed files with 29 additions and 11 deletions
  1. +22 -5   src/server/backend.py
  2. +7 -6    src/server/handler.py

+ 22 - 5
src/server/backend.py

@@ -18,7 +18,6 @@ class TransformerBackend(ModuleBackend):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, BloomBlock)
         self.memory_cache = memory_cache
-
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
         for name, buf in self.module.named_buffers():
@@ -28,11 +27,29 @@ class TransformerBackend(ModuleBackend):
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         attention_cache_handle = int(cache_metadata[0, 0].item())
-        current_sequence_length = int(cache_metadata[0, 1].item())
+        prefix_length = int(cache_metadata[0, 1].item())
+        hidden_states, *_ = inputs
+        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+
         with self.memory_cache.use_cache(attention_cache_handle) as cache:
-            print('METADATA:', cache_metadata, "CACHE", cache.mean(), "CACHE ENTRIES:", len(self.memory_cache._allocated_tensors))
-            cache[...] += 1
-            return (inputs[0] + cache.flatten()[0],)
+            print('METADATA:', cache_metadata)
+            assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+            layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
+            print(past_k.shape, past_v.shape)
+            hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+
+
+            # TODO: remove these debug prints
+            new_length = new_v.shape[1]
+            assert new_length > prefix_length
+            assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
+            assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
+            assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
+            assert torch.allclose(new_v[:, :past_v.shape[1]], past_v)
+            assert torch.allclose(new_k[:, :past_k.shape[1]], past_k)
+            cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
+            cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
+            return (hidden_states,)
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool
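
The new inference_step above keeps past attention keys and values in a server-side cache tensor (shape [2, batch, MAX_LENGTH, num_heads, head_dim], allocated by the descriptor in handler.py below) and writes the freshly computed keys/values into positions prefix_length:new_length on every call. Below is a minimal sketch of just that append step; the helper name append_to_kv_cache and the exact [batch, seq_len, num_heads, head_dim] key/value layout are illustrative assumptions, not part of this commit.

import torch

def append_to_kv_cache(cache: torch.Tensor, new_k: torch.Tensor, new_v: torch.Tensor, prefix_length: int) -> int:
    # cache: [2, batch, MAX_LENGTH, num_heads, head_dim]; index 0 stores keys, index 1 stores values
    # new_k, new_v: [batch, new_length, num_heads, head_dim] as returned with use_cache=True,
    #   where the first prefix_length positions repeat the already-cached entries
    new_length = new_k.shape[1]
    assert new_length <= cache.shape[2], "sequence would overflow the pre-allocated MAX_LENGTH"
    cache[0, :, prefix_length:new_length] = new_k[:, prefix_length:new_length]
    cache[1, :, prefix_length:new_length] = new_v[:, prefix_length:new_length]
    return new_length  # used as prefix_length on the next call

# toy usage: a 2-token prefix followed by one more token
cache = torch.zeros(2, 1, 16, 4, 8)
prefix = append_to_kv_cache(cache, torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), prefix_length=0)       # prefix == 2
prefix = append_to_kv_cache(cache, torch.randn(1, 3, 4, 8), torch.randn(1, 3, 4, 8), prefix_length=prefix)  # prefix == 3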

+ 7 - 6
src/server/handler.py

@@ -30,21 +30,22 @@ class TransformerConnectionHandler(ConnectionHandler):
             assert isinstance(backend, TransformerBackend)
 
             # prepare attention cache
-            hidden_size = backend.module.hidden_size
-            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, current_sequence_length]
-            cache_descriptor = TensorDescriptor(size=(1, MAX_LENGTH, hidden_size), dtype=torch.float32)
-            current_sequence_length = 0
+            num_heads = backend.module.self_attention.num_heads
+            head_dim = backend.module.self_attention.head_dim
+            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, prefix_length]
+            cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
+            prefix_length = 0
 
             async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
                 while request.uid or request.tensors:  # iterate while user is willing to supply tensors
                     inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
                     print("INPUTS:", inputs)
                     assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
-                    cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, current_sequence_length
+                    cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
                     outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
                     yield runtime_pb2.ExpertResponse(tensors=outputs)
 
-                    current_sequence_length += inputs[1].shape[1]
+                    prefix_length += inputs[1].shape[1]
                     request = await(anext(requests))
         finally:
             print("CLOSED RPC_INFERENCE")