Forráskód Böngészése

basic chained inference (multiple blocks per one RPC call)

justheuristic 3 éve
szülő
commit
1cdf8a77fb
4 módosított fájl, 134 hozzáadás és 25 törlés
  1. 2 1
      README.md
  2. 2 2
      src/server/cache.py
  3. 66 22
      src/server/handler.py
  4. 64 0
      tests/test_chained_inference.py

+ 2 - 1
README.md

@@ -51,8 +51,9 @@ import torch
 import hivemind
 from src import get_remote_module
 
+
 dht = hivemind.DHT(
-    initial_peers=["/ip4/127.0.0.1/COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS"],
+    initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS],  # e.g. /ip4/127.0.0.1/...
     client_mode=True, start=True,
 )
 

+ 2 - 2
src/server/cache.py

@@ -8,7 +8,7 @@ import contextlib
 import ctypes
 import multiprocessing as mp
 import os
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Union, AsyncContextManager
 
 import hivemind
 import torch
@@ -54,7 +54,7 @@ class MemoryCache:
         self._handle_counter.value = value
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, descr: TensorDescriptor) -> Handle:
+    async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
 

+ 66 - 22
src/server/handler.py

@@ -1,21 +1,24 @@
-from typing import AsyncIterator, Dict
+import contextlib
+from typing import AsyncIterator, Dict, Sequence
 
 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten
+from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
 
+from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
 
 
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
+    module_backends: Dict[ModuleUID, TransformerBackend]
 
     def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
-        for module_backend in module_backends.values():
-            assert isinstance(module_backend, TransformerBackend)
         super().__init__(dht, module_backends)
+        for module_backend in self.module_backends.values():
+            assert isinstance(module_backend, TransformerBackend)
 
     async def rpc_inference(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -24,28 +27,69 @@ class TransformerConnectionHandler(ConnectionHandler):
         try:
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
-            if not request.uid:
-                raise RuntimeError("User did not provide any uids.")
-            backend = self.module_backends[request.uid]
-            assert isinstance(backend, TransformerBackend)
-
-            # prepare attention cache
-            num_heads = backend.module.self_attention.num_heads
-            head_dim = backend.module.self_attention.head_dim
+            requested_uids = self._check_header(request)
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
             cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, prefix_length]
-            cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
             prefix_length = 0
 
-            async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
-                while request.uid or request.tensors:  # iterate while user is willing to supply tensors
-                    inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
-                    print("INPUTS:", inputs)
-                    assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
-                    cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
-                    outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
-                    yield runtime_pb2.ExpertResponse(tensors=outputs)
+            async with self._allocate_caches(requested_backends) as cache_handles:
+                assert len(cache_handles) == len(requested_backends)
+                while request.tensors:  # iterate while user is willing to supply tensors
+                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+
+                    # run request tensors through all requested modules, update caches
+                    for backend, cache_handle in zip(requested_backends, cache_handles):
+                        cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
+                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3, \
+                            f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+
+                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
+                        assert isinstance(hidden_states, (list, tuple))
+                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
 
-                    prefix_length += inputs[1].shape[1]
+                    # serialize and send last layer outputs
+                    yield runtime_pb2.ExpertResponse(tensors=[
+                        serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                        for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+                    ])
+
+                    # prepare for next step
+                    prefix_length += hidden_states[0].shape[1]
                     request = await (anext(requests))
         finally:
             print("CLOSED RPC_INFERENCE")
+
+    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
+        """Check that the first request to rpc_inference is valid"""
+        uids = (request.uid or '').split(CHAIN_DELIMITER)
+        if not uids:
+            raise RuntimeError("User did not provide any uids")
+        for uid in uids:
+            if uid not in self.module_backends:
+                raise RuntimeError(f"Remote peer does not serve {uid}")
+        return tuple(uids)
+
+    @contextlib.asynccontextmanager
+    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
+        """Allocate memory caches for each transformer block, return cache handles"""
+        async with contextlib.AsyncExitStack() as stack:
+            handles = []
+            for backend in backends:
+                num_heads = backend.module.self_attention.num_heads
+                head_dim = backend.module.self_attention.head_dim
+
+                cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
+                # [key_or_value, batch_size, max_length, num_heads, head_dim]
+
+                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
+
+            yield handles
+
+
+
+
+
+
+
+

+ 64 - 0
tests/test_chained_inference.py

@@ -0,0 +1,64 @@
+######
+# Warning:torch this test is a work in progress. It will be modified soon.
+# - if you want more stable tests, see test_block_exact_match
+# - if you want to figure out chained inference, ask yozh
+
+import os
+
+import hivemind
+import torch
+from hivemind.moe.expert_uid import ExpertInfo
+
+from src.bloom.from_pretrained import load_pretrained_block
+from src.client.remote_block import RemoteTransformerBlock
+from src.dht_utils import get_remote_module
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+BLOCK_UID = os.environ.get("BLOCK_UID")
+if not BLOCK_UID:
+    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
+
+REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
+REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1]))
+
+
+def test_remote_block_exact_match(atol_inference=1e-4):
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    remote_block = get_remote_module(dht, BLOCK_UID)
+    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
+    assert isinstance(remote_block, RemoteTransformerBlock)
+
+    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
+    remote_block._info = ExpertInfo('bloom6b3.3 bloom6b3.4', remote_block._info.peer_id)
+
+    inputs = torch.randn(1, 8, 4096)
+
+    outputs_inference = []
+    with remote_block.begin_inference_session() as sess:
+        for i in range(inputs.shape[1]):
+            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+    outputs_inference = torch.cat(outputs_inference, dim=1)
+
+    ref_blocks = [
+        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32)
+    ]
+    outputs_ref = []
+    caches = [None, None]
+    for i in range(inputs.shape[1]):
+        new_caches = []
+        hidden_states = inputs[:, i : i + 1, :]
+        for ref_block, cache in zip(ref_blocks, caches):
+            with torch.no_grad():
+                hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
+                new_caches.append(new_cache)
+
+        outputs_ref.append(hidden_states)
+        caches = new_caches
+    outputs_ref = torch.cat(outputs_ref, dim=1)
+    assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)