
crutch fix for tests

justheuristic 3 years ago
Parent commit: 4e64597d37
1 changed file with 53 additions and 51 deletions

+ 53 - 51
src/server/handler.py

@@ -1,8 +1,9 @@
 import contextlib
-from typing import AsyncIterator, Dict, Sequence
+from typing import AsyncIterator, Dict, Sequence, Optional, List
 
 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
+from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, \
+    serialize_torch_tensor, MSGPackSerializer
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -34,7 +35,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         try:
             print("OPENED RPC_INFERENCE")
             request = await anext(requests)
-            requested_uids = self._check_header(request)
+            requested_uids = self._check_uids(request.uid)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
@@ -81,18 +82,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
-        inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        hidden_states = await _rpc_forward(inputs, requested_backends)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
-        # Serialize the overall output and respond
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
+        # Serialize output and respond to client
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
             ]
         )
 
@@ -100,20 +101,20 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         # Parse requests and prepare backends
-        uids_header, inputs = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
+        uid_str, flat_inputs = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uid_str)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        hidden_states = await _rpc_forward(inputs, requested_backends)
+        hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends)
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
 
         # Serialize the overall output
-        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
         serialized_output = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
+            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]
 
-        # Split the serialized_output for streaming and respond
+        # Split the serialized_output for streaming and respond to client
         output_split = [
             part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
         ]
@@ -122,11 +123,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse requests and prepare backends
-        inputs, prompts, grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_header(request)
+        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+        requested_uids = self._check_uids(request.uid)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
@@ -147,11 +148,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
 
-        uids_header, (inputs, prompts, grad_outputs) = await self._gather_inputs(requests, context)
-        requested_uids = self._check_header_str(uids_header)
+        uids_header, flat_tensors = await self._gather_inputs(requests, context)
+        requested_uids = self._check_uids(uids_header)
         requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
-        grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
+        grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
         assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
@@ -173,19 +174,9 @@ class TransformerConnectionHandler(ConnectionHandler):
         async for part in as_aiter(*output_split):
             yield runtime_pb2.ExpertResponse(tensors=[part])
 
-    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
+    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or "").split(CHAIN_DELIMITER)
-        if not uids:
-            raise RuntimeError("User did not provide any uids")
-        for uid in uids:
-            if uid not in self.module_backends:
-                raise RuntimeError(f"Remote peer does not serve {uid}")
-        return tuple(uids)
-
-    def _check_header_str(self, header) -> Sequence[ModuleUID]:
-        """Check that the first request to rpc_inference is valid"""
-        uids = (header or "").split(CHAIN_DELIMITER)
+        uids = (uids or "").split(CHAIN_DELIMITER)
         if not uids:
             raise RuntimeError("User did not provide any uids")
         for uid in uids:
@@ -212,32 +203,42 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield handles
 
 
-async def _rpc_forward(inputs, requested_backends):
-    # Cast inputs to backend dtype
-    inputs = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
-    assert len(inputs) == 2 and inputs[0].ndim == 3
-    hidden_states, prompts = inputs
+async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend]) -> torch.Tensor:
+    """
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
 
-    if is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
-    else:
-        pre_seq_len = prompts.shape[2]
+    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
+    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
+    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    """
+    hidden_states, *prompts = flat_tensors
+    dtype = requested_backends[0].dtype
+    # check and parse input tensors, cast them to the backend dtype
+    hidden_states = hidden_states.to(dtype)
+    assert hidden_states.ndim == 3
+    assert len(prompts) <= len(requested_backends), f"Expected at most {len(requested_backends)} prompts, one per layer"
 
-    # Run a chain of requested backends
+    for i in range(len(prompts)):
+        if not is_dummy(prompts[i]):
+            assert prompts[i].ndim == 3, "prompts must have shape [batch or 1, seq_len or prefix, hidden_size]"
+            prompts[i] = prompts[i].to(dtype)
+    prompts.extend((DUMMY for _ in range(len(prompts), len(requested_backends))))  # add missing prompts
+
+    seq_length = hidden_states.shape[1]
+
+    # run forward pass for requested backends
     for backend, prompt in zip(requested_backends, prompts):
-        if not is_dummy(prompt):
-            hidden_states[:, :pre_seq_len] += prompt
         (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
+        if not is_dummy(prompt):
+            hidden_states[:, :min(seq_length, prompt.shape[1]), ...] += prompt
         assert isinstance(hidden_states, torch.Tensor)
-        assert (
-            hidden_states.ndim == 3
-        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+        assert hidden_states.ndim == 3, f"{type(backend)} must return a list with a single 3d tensor of hidden states"
 
-    # Serialize the overall output
-    return [hidden_states]
+    return hidden_states
 
 
-async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
+async def _rpc_backward(inputs: torch.Tensor, prompts: torch.Tensor, grad_outputs: torch.Tensor, requested_backends):
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     prompts = prompts.to(requested_backends[0].dtype)
@@ -255,6 +256,7 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
+            inputs = inputs.clone()  # TODO: avoid this extra copy; clone so the in-place prompt addition does not modify the original inputs
             inputs[:, :pre_seq_len] += prompt
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
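
For reference, below is a minimal, self-contained sketch of the per-layer prompt handling that the new _rpc_forward performs: unpack hidden states and optional prompts, pad the missing prompts with dummy tensors, then run the chain of blocks and add each non-dummy prompt to the leading positions of the hidden states. DUMMY, is_dummy, and the toy blocks here are simplified stand-ins, not the actual hivemind/petals objects; the real code submits hidden_states to each backend's forward_pool.

    # sketch only: toy stand-ins for DUMMY, is_dummy, and the transformer blocks
    import torch

    DUMMY = torch.empty(0)  # empty placeholder tensor, analogous to the DUMMY used in the diff

    def is_dummy(tensor: torch.Tensor) -> bool:
        return tensor.numel() == 0

    def forward_chain(hidden_states, prompts, blocks):
        # hidden_states: [batch_size, seq_length, hid_size]; at most one prompt per block
        assert hidden_states.ndim == 3
        assert len(prompts) <= len(blocks)
        prompts = list(prompts) + [DUMMY] * (len(blocks) - len(prompts))  # pad missing prompts
        seq_length = hidden_states.shape[1]
        for block, prompt in zip(blocks, prompts):
            hidden_states = block(hidden_states)
            if not is_dummy(prompt):
                # add the prompt to the leading positions, clipped to the sequence length
                hidden_states[:, : min(seq_length, prompt.shape[1]), :] += prompt
        return hidden_states

    # toy usage: two identity "blocks", one real prompt applied after the first layer only
    blocks = [lambda x: x, lambda x: x]
    out = forward_chain(torch.zeros(1, 4, 8), [torch.ones(1, 2, 8)], blocks)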