
Alloc inference cache as one contiguous buffer (#160)

Alexander Borzunov 2 years ago
commit 7cdc57a04b
2 changed files with 42 additions and 44 deletions
  1. src/petals/server/backend.py (+12, -12)
  2. src/petals/server/handler.py (+30, -32)

+ 12 - 12
src/petals/server/backend.py

@@ -48,25 +48,25 @@ class TransformerBackend(ModuleBackend):
             self.kwargs_schema,
         )
 
-    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+    def inference_step(
+        self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, cache_metadata: torch.LongTensor
+    ) -> Tuple[torch.Tensor, ...]:
         num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
         with torch.inference_mode():
-            attention_cache_handle = int(cache_metadata[0, 0].item())
-            prefix_length = int(cache_metadata[0, 1].item())
-            (hidden_states, hypo_ids) = inputs
             assert (
                 hidden_states.ndim == 3
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+            cache_handle, rel_index, prefix_length = map(int, cache_metadata[0])
 
-            with self.memory_cache.use_cache(attention_cache_handle) as cache:
-                batch_size = cache.shape[1]
-                max_length = cache.numel() // (2 * batch_size * head_dim * num_heads)
-                assert isinstance(self.module, WrappedBloomBlock) and cache.shape[0] == 2 and cache.ndim == 3
+            with self.memory_cache.use_cache(cache_handle) as cache:
+                batch_size = cache.shape[2]
+                max_length = cache.shape[-1] // (head_dim * num_heads)
+                assert isinstance(self.module, WrappedBloomBlock) and cache.shape[1] == 2 and cache.ndim == 4
                 if not is_dummy(hypo_ids):
-                    assert hypo_ids.shape[0] == cache.shape[1]
-                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
-                key_cache = cache[0].view(batch_size, num_heads, head_dim, max_length)
-                value_cache = cache[1].view(batch_size, num_heads, max_length, head_dim)
+                    assert hypo_ids.shape[0] == batch_size
+                    cache[rel_index, :, :] = cache[rel_index, :, hypo_ids]  # in-place reorder cache by hypo ids
+                key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
+                value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)
 
                 key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
                 value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
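
With this change the backend no longer receives a per-block cache handle: it gets one handle to a single 4-D buffer covering all blocks it serves and selects its own slice with `rel_index`. Below is a minimal, self-contained sketch of that layout and of the key/value views derived from it; the sizes are invented for illustration and only torch is assumed, not the petals `MemoryCache`.

```python
import torch

# Illustrative sizes only; the real values come from the served model and the client request.
n_blocks, batch_size, num_heads, head_dim, max_length = 3, 2, 4, 8, 16

# One contiguous buffer for all served blocks:
# [block, key-or-value, batch, flattened num_heads * head_dim * max_length]
cache = torch.zeros(n_blocks, 2, batch_size, num_heads * head_dim * max_length)

rel_index, prefix_length = 1, 5  # which block inside the buffer, and how many tokens are already cached

# Same reshaping as in inference_step: keys stored as [batch, heads, head_dim, seq],
# values as [batch, heads, seq, head_dim] (the BLOOM layer_past layout).
key_cache = cache[rel_index, 0].view(batch_size, num_heads, head_dim, max_length)
value_cache = cache[rel_index, 1].view(batch_size, num_heads, max_length, head_dim)

key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]      # [batch * num_heads, head_dim, kv_length]
value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
print(key_past.shape, value_past.shape)  # torch.Size([8, 8, 5]) torch.Size([8, 5, 8])
```

The flat last dimension is what lets one contiguous allocation back every block while each block still reshapes its own slice into the BLOOM `layer_past` layout.
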

+ 30 - 32
src/petals/server/handler.py

@@ -119,12 +119,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                 batch_size = request.tensors[0].size[0] if request.tensors else 1
 
                 cache_metadata = torch.tensor(
-                    [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
-                )  # [cache_handle, prefix_length]
+                    [[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64
+                )  # [cache_handle, rel_index, prefix_length]
                 prefix_length = 0
 
-                async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
-                    assert len(cache_handles) == len(requested_backends)
+                async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handle:
                     while request.tensors:  # iterate while user is willing to supply tensors
                         hidden_states, prompts, hypo_ids = [
                             deserialize_torch_tensor(tensor) for tensor in request.tensors
@@ -151,14 +150,16 @@ class TransformerConnectionHandler(ConnectionHandler):
                             )
 
                         # run request tensors through all requested modules, update caches
-                        for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                        for rel_index, (backend, prompt) in enumerate(zip(requested_backends, prompts)):
                             if not is_dummy(prompt):
                                 hidden_states[:, : prompt.shape[1]] += prompt
                             if hidden_states.numel() == 0:
                                 continue  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
                                 # when user wants to pre-allocate cache or check that server *can* allocate that cache
 
-                            cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                            cache_metadata[:] = torch.tensor(
+                                [cache_handle, rel_index, prefix_length], dtype=torch.int64
+                            )
                             assert isinstance(
                                 hidden_states, torch.Tensor
                             ), f"hidden states must be tensor, got {type(hidden_states)}"
@@ -177,7 +178,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 type="inference",
                             )
                             (hidden_states,) = await backend.inference_pool.submit_task(
-                                cache_metadata, hidden_states, hypo_ids, priority=priority
+                                hidden_states, hypo_ids, cache_metadata, priority=priority
                             )
 
                         # serialize and send last layer outputs
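
The handler now ships three integers per step instead of two, and `cache_metadata` moved to the last position of `inference_step`. A small sketch of how the metadata row is built on the handler side and unpacked on the backend side; the concrete values are hypothetical.

```python
import torch

batch_size = 2

# One row per sequence in the batch; columns are [cache_handle, rel_index, prefix_length].
cache_metadata = torch.tensor([[-1, -1, -1] for _ in range(batch_size)], dtype=torch.int64)

# Hypothetical values for one block at one generation step.
cache_handle, rel_index, prefix_length = 7, 0, 12
cache_metadata[:] = torch.tensor([cache_handle, rel_index, prefix_length], dtype=torch.int64)

# The backend unpacks the first row, matching the new signature
# inference_step(hidden_states, hypo_ids, cache_metadata):
handle, block_index, prefix = map(int, cache_metadata[0])
print(handle, block_index, prefix)  # 7 0 12
```
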
@@ -343,33 +344,30 @@ class TransformerConnectionHandler(ConnectionHandler):
         return tuple(uids)
 
     @contextlib.asynccontextmanager
-    async def _allocate_caches(
+    async def _allocate_cache(
         self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
     ) -> Sequence[int]:
-        """Allocate memory caches for each transformer block, return cache handles"""
-        async with contextlib.AsyncExitStack() as stack:
-            handles = []
-            total_size = 0
-            backend = None
-            for backend in backends:
-                num_heads = backend.module.self_attention.num_heads
-                head_dim = backend.module.self_attention.head_dim
-                descr = TensorDescriptor(size=(2, batch_size, num_heads * head_dim * max_length), dtype=backend.dtype)
-                # ^-- flattened batch-first tensor of both keys and values; based on BLOOM layer_past layout
-                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
-                total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
-
-            gib = 1024**3
-            if backend is not None:
-                cur_size = backend.memory_cache.current_size_bytes
-                max_size = backend.memory_cache.max_size_bytes
-                friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
-                cache_stats = f"used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
-            else:
-                cache_stats = f"cache stats n/a"
-            logger.info(f"rpc_inference.alloc(total_size={total_size / gib:.2f} GiB), {cache_stats}")
-
-            yield handles
+        """Allocate memory cache for all transformer blocks, return cache handle"""
+
+        n_blocks = len(backends)
+        backend = backends[0]
+        n_heads = backend.module.self_attention.num_heads
+        head_dim = backend.module.self_attention.head_dim
+        descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=backend.dtype)
+        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+
+        gib = 1024**3
+        cur_size = backend.memory_cache.current_size_bytes
+        max_size = backend.memory_cache.max_size_bytes
+        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
+        logger.info(
+            f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB), "
+            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
+        )
+
+        async with backend.memory_cache.allocate_cache(descr) as handle:
+            logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
+            yield handle
 
     def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
         friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
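
For reference, the allocation that `_allocate_cache` requests is a single descriptor whose size grows linearly with the number of served blocks. A self-contained sketch of the size arithmetic behind the log message above; the `TensorDescriptor` here is a stand-in namedtuple rather than the petals class, and the model dimensions are illustrative.

```python
from collections import namedtuple

import torch

# Stand-in for petals' TensorDescriptor, just enough to mirror the size math in the diff.
TensorDescriptor = namedtuple("TensorDescriptor", ["size", "dtype"])

# Illustrative dimensions; real ones come from the served blocks and the client request.
n_blocks, batch_size, n_heads, head_dim, max_length = 4, 1, 32, 128, 2048
dtype = torch.bfloat16

descr = TensorDescriptor(size=(n_blocks, 2, batch_size, n_heads * head_dim * max_length), dtype=dtype)

numel = 1
for dim in descr.size:
    numel *= dim
alloc_size = numel * torch.finfo(descr.dtype).bits // 8  # bytes for the whole contiguous KV cache

gib = 1024**3
print(f"rpc_inference.wait_for_alloc(size={alloc_size / gib:.2f} GiB)")  # about 0.12 GiB for these sizes
```

Allocating once per inference session instead of once per block is what the commit title refers to: a single contiguous buffer, with `rel_index` addressing each block's slice inside it.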