add chained rpc_forward & rpc_backward

Dmitry Baranchuk, 3 years ago
Commit 4cb986f680
2 files changed: 206 insertions, 0 deletions
  1. src/server/handler.py (+147, -0)
  2. tests/test_chained_forward_backward.py (+59, -0)
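
For context: each handler below runs a chain of transformer blocks in a single RPC call. The client encodes the chain by joining module UIDs with CHAIN_DELIMITER in the request's uid field, which _check_header then splits and validates. A minimal sketch of building such a header (UIDs taken from the test below; CHAIN_DELIMITER is assumed to be a single space, as the test's ExpertInfo suggests):

    from src.data_structures import CHAIN_DELIMITER

    module_uids = ["bloom6b3.3", "bloom6b3.4", "bloom6b3.5"]
    uid_header = CHAIN_DELIMITER.join(module_uids)  # -> "bloom6b3.3 bloom6b3.4 bloom6b3.5"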

src/server/handler.py  (+147, -0)

@@ -7,6 +7,9 @@ from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import anext
+from hivemind.utils.streaming import split_for_streaming
+from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.utils import as_aiter
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
@@ -67,6 +70,140 @@ class TransformerConnectionHandler(ConnectionHandler):
         finally:
             print("CLOSED RPC_INFERENCE")
 
+    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        # Parse request and prepare backends
+        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_header(request)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a chain of requested backends 
+        for backend in requested_backends:
+            assert isinstance(hidden_states, (list, tuple))
+            assert (
+                len(hidden_states) == 1 and hidden_states[0].ndim == 3
+            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+        
+        # Serialize the overall output and respond
+        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        return runtime_pb2.ExpertResponse(tensors=[
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(
+                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+            )
+        ])
+
+    async def rpc_forward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        # Parse requests and prepare backends
+        uids_header, hidden_states = await self._gather_inputs(requests, context)
+        requested_uids = self._check_header_str(uids_header)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a chain of requested backends 
+        for backend in requested_backends:
+            assert isinstance(hidden_states, (list, tuple))
+            assert (
+                len(hidden_states) == 1 and hidden_states[0].ndim == 3
+            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+            hidden_states = await backend.forward_pool.submit_task(*hidden_states)
+        
+        # Serialize the overall output
+        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        serialized_output = [
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(
+                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+            )
+        ]
+
+        # Split the serialized_output for streaming and respond
+        output_split = [
+            part
+            for tensor in serialized_output
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+        ]
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])
+
+    async def rpc_backward(
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
+        # Parse requests and prepare backends
+        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_header(request)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a forward chain to collect intermediate inputs
+        # Note: we do not run the last module's forward pass since its output is not needed
+        inter_inputs = [inputs]
+        for backend in requested_backends[:-1]:
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            inputs = await backend.forward_pool.submit_task(inputs)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
+            inputs = inputs[0]
+            inter_inputs.append(inputs)
+
+        # Run a backward chain over the requested backends in reverse order
+        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
+            grads = await backend.backward_pool.submit_task(inp, grads)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
+            grads = grads[0]
+        
+        # Serialize the overall grad_input and respond
+        return runtime_pb2.ExpertResponse(tensors=[
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(
+                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
+            )
+        ])
+
+    async def rpc_backward_stream(
+        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
+        inputs, grads = inputs_and_grads
+        requested_uids = self._check_header_str(uids_header)
+        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+
+        # Run a forward chain to collect intermediate inputs
+        # Note: we do not run the last module's forward pass since its output is not needed
+        inter_inputs = [inputs]
+        for backend in requested_backends[:-1]:
+            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+            inputs = await backend.forward_pool.submit_task(inputs)
+            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
+            inputs = inputs[0]
+            inter_inputs.append(inputs)
+
+        # Run a backward chain over the requested backends in reverse order
+        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
+            grads = await backend.backward_pool.submit_task(inp, grads)
+            assert isinstance(grads, (list, tuple)) and len(grads) == 1
+            grads = grads[0]
+        
+        # Serialize the overall grad_inputs
+        serialized_grad_inputs = [
+            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+            for result, proto in zip(
+                [grads], nested_flatten(requested_backends[0].grad_inputs_schema)
+            )
+        ]
+        # Split the serialized_grad_inputs for streaming and respond
+        output_split = [
+            part
+            for tensor in serialized_grad_inputs
+            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+        ]
+
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])
+
     def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""
         uids = (request.uid or "").split(CHAIN_DELIMITER)
@@ -77,6 +214,16 @@ class TransformerConnectionHandler(ConnectionHandler):
                 raise RuntimeError(f"Remote peer does not serve {uid}")
         return tuple(uids)
 
+    def _check_header_str(self, header) -> Sequence[ModuleUID]:
+        """Check that the first request to rpc_inference is valid"""
+        uids = (header or "").split(CHAIN_DELIMITER)
+        if not uids:
+            raise RuntimeError("User did not provide any uids")
+        for uid in uids:
+            if uid not in self.module_backends:
+                raise RuntimeError(f"Remote peer does not serve {uid}")
+        return tuple(uids)
+
     @contextlib.asynccontextmanager
     async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
         """Allocate memory caches for each transformer block, return cache handles"""

tests/test_chained_forward_backward.py  (+59, -0)

@@ -0,0 +1,59 @@
+######
+# Warning: this test is a work in progress. It will be modified soon.
+# - if you want more stable tests, see test_block_exact_match
+# - if you want to figure out chained inference, ask yozh
+
+import os
+
+import hivemind
+import torch
+from hivemind.moe.expert_uid import ExpertInfo
+
+from src.bloom.from_pretrained import load_pretrained_block
+from src.client.remote_block import RemoteTransformerBlock
+from src.dht_utils import get_remote_module
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+BLOCK_UID = os.environ.get("BLOCK_UID")
+if not BLOCK_UID:
+    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
+
+REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
+
+
+# seq_length > 128: rpc_forward_stream & rpc_backward_stream
+# seq_length <= 128: rpc_forward & rpc_backward
+def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    (remote_block,) = get_remote_module(dht, BLOCK_UID)
+    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
+    assert isinstance(remote_block, RemoteTransformerBlock)
+    
+    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
+    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id)
+
+    ref_blocks = [
+        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
+    ]        
+    inputs = torch.randn(1, seq_length, 4096, requires_grad=True)
+    outputs_rpc = remote_block.forward(inputs)[0]
+    outputs_rpc.sum().backward()
+    grads_rpc = inputs.grad
+
+    inputs.grad = None
+    hidden_states = inputs
+    for ref_block in ref_blocks:
+        hidden_states = ref_block.forward(hidden_states)[0]
+    outputs_ref = hidden_states
+    outputs_ref.sum().backward()
+    grads_ref = inputs.grad
+
+    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
+    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
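
As the comment above test_forward_backward_exact_match notes, seq_length picks the code path: short sequences use the unary rpc_forward / rpc_backward, longer ones the streaming variants. A minimal sketch of exercising both paths (the 128-token threshold is taken from that comment; the actual cutoff is applied on the client side):

    def test_both_code_paths():
        test_forward_backward_exact_match(seq_length=1)    # unary rpc_forward / rpc_backward
        test_forward_backward_exact_match(seq_length=256)  # rpc_forward_stream / rpc_backward_stream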