
remove residuals in all RPCs

justheuristic 2 years ago
parent
commit
16c3c0bf3d

+ 10 - 3
src/client/remote_forward_backward.py

@@ -66,7 +66,7 @@ async def run_remote_forward(
     uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
-    Serializes input tensors and calls "rpc_forward" on a remote server.
+    Serializes input tensors and calls "rpc_forward" on a remote server, returns block outputs (with residuals restored)
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
@@ -100,13 +100,16 @@ async def run_remote_forward(
         )
     )
 
-    # call RPC on remote server
+    # call RPC on remote server, receive last hidden states *without* the residual component
     size = sum(t.element_size() * t.nelement() for t in inputs)
     if size > MAX_UNARY_PAYLOAD_SIZE:
         deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
     else:
         deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
+    input_was_compressed = serialized_tensors[0].compression != runtime_pb2.CompressionType.NONE
+    residual = deserialize_torch_tensor(serialized_tensors[0]) if input_was_compressed else inputs[0]
+    deserialized_outputs[0].add_(residual)  # restore the residual, matching the server's (possibly lossy) view of the input
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
@@ -121,7 +124,7 @@ async def run_remote_backward(
     **kwargs,
 ) -> Sequence[torch.Tensor]:
     """
-    Serializes grad outputs and calls "rpc_backward" on a remote server.
+    Serializes grad outputs and calls "rpc_backward" on a remote server, returns grad w.r.t. inputs (with residuals restored)
     Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
     but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
     """
@@ -146,9 +149,13 @@ async def run_remote_backward(
     )
 
     size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
+    # call RPC on remote server; its backward returns gradients *without* the residual component
     if size > MAX_UNARY_PAYLOAD_SIZE:
         deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
     else:
         deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
+    grad_output_was_compressed = serialized_tensors[1].compression != runtime_pb2.CompressionType.NONE
+    residual_grad = deserialize_torch_tensor(serialized_tensors[1]) if grad_output_was_compressed else inputs_and_grad_outputs[1]
+    deserialized_grad_inputs[0].add_(residual_grad)  # restore the grad_outputs residual, matching the server's view
     return deserialized_grad_inputs
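
Why subtracting the input before serialization helps: consecutive transformer blocks only nudge the hidden states, so outputs - inputs spans a far smaller dynamic range than outputs itself and survives lossy codecs much better. Below is a minimal sketch of that effect (my own illustration, not part of this commit); fake_quantize is a toy stand-in for hivemind's CompressionType codecs.

import torch

def fake_quantize(t: torch.Tensor, levels: int = 256) -> torch.Tensor:
    # uniform quantization over the tensor's full range -- a crude model of lossy compression
    lo, hi = t.min(), t.max()
    scale = (hi - lo) / (levels - 1)
    return torch.round((t - lo) / scale) * scale + lo

torch.manual_seed(0)
inputs = torch.randn(2, 16, 1024)                   # hidden states entering the blocks
outputs = inputs + 0.05 * torch.randn_like(inputs)  # deep blocks change hidden states only slightly

err_direct = (fake_quantize(outputs) - outputs).abs().max()
err_residual = (inputs + fake_quantize(outputs - inputs) - outputs).abs().max()
print(f"direct: {err_direct:.1e}, residual trick: {err_residual:.1e}")  # the delta quantizes roughly 20x more accurately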

+ 24 - 20
src/server/handler.py

@@ -121,6 +121,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         hidden_states, prompts, hypo_ids = [
                             deserialize_torch_tensor(tensor) for tensor in request.tensors
                         ]
+                        initial_hidden_states = hidden_states.clone()
 
                         # Cast inputs to backend dtype
                         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -169,14 +170,10 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 cache_metadata, hidden_states, hypo_ids, priority=priority
                             )
 
-                        # serialize and send last layer outputs
+                        # serialize and send last layer outputs without the residual component
+                        outputs = hidden_states - initial_hidden_states
                         yield runtime_pb2.ExpertResponse(
-                            tensors=[
-                                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                                for result, proto in zip(
-                                    (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
-                                )
-                            ]
+                            tensors=self._serialize_outputs(outputs, requested_backends, metadata)
                         )
 
                         # prepare for next step
@@ -199,11 +196,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            outputs = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
             return runtime_pb2.ExpertResponse(
-                tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
+                tensors=self._serialize_outputs(outputs, requested_backends, metadata)
             )
 
     async def rpc_forward_stream(
@@ -221,12 +218,12 @@ class TransformerConnectionHandler(ConnectionHandler):
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
-            hidden_states = await _rpc_forward(
+            outputs = await _rpc_forward(
                 *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
             )
 
             # Split the serialized_output for streaming and respond to client
-            for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata):
+            for tensor in self._serialize_outputs(outputs, requested_backends, metadata):
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                     yield runtime_pb2.ExpertResponse(tensors=[part])
 
@@ -379,12 +376,15 @@ async def _rpc_forward(
     points: int = 0,
 ) -> torch.Tensor:
     """
-    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
+    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream.
+    A forward pass computes transformer hidden states after applying all requested_backends, then strips the residual component.
+    In other words, it returns the last hidden states minus the input hidden states provided by the user.
 
     :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
     :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
-    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
+    :returns: hidden states after the last layer *without residuals*, [batch_size, seq_length, hid_size]
+    :note: this method returns layerN(...layer1(inputs)...) - inputs to reduce compression error
     """
     hidden_states, prompts = flat_tensors
     dtype = requested_backends[0].dtype
@@ -395,6 +395,7 @@ async def _rpc_forward(
         prompts = [DUMMY] * len(requested_backends)
     else:
         prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+    initial_hidden_states = hidden_states.clone()
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
@@ -405,17 +406,14 @@ async def _rpc_forward(
         priority = prioritizer.prioritize(
             hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
         )
-        (hidden_states,) = await backend.forward_pool.submit_task(
-            hidden_states,
-            priority=priority,
-        )
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, priority=priority)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
 
-    # Serialize the overall output
-    return hidden_states
+    # Return the difference between last hidden states and input hidden states (remove the residual component)
+    return torch.sub(hidden_states, initial_hidden_states, out=initial_hidden_states)
 
 
 async def _rpc_backward(
@@ -424,10 +422,15 @@ async def _rpc_backward(
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
+    """
+    Backpropagate gradients through _rpc_forward, return gradients w.r.t. inputs and optional prompts without residuals
+    :note: like in rpc_forward, this method returns (grad_input - grad_output) for better compression
+    """
     inputs, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+    original_grad_outputs = grad_outputs.clone()
 
     if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
@@ -469,4 +472,5 @@ async def _rpc_backward(
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
 
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
-    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
+    grad_inputs = torch.sub(grad_outputs, original_grad_outputs, out=original_grad_outputs)  # grad_outputs now holds grad w.r.t. inputs; strip the residual
+    return [grad_inputs] if is_dummy(grad_prompts) else [grad_inputs, grad_prompts]  # TODO un-duct-tape
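
To see that the client can reconstruct exact results from residual-free replies, here is a self-contained sanity check (my own sketch, assuming no compression): the server sends hidden_states - initial_hidden_states in forward and grad_inputs - grad_outputs in backward, and adding the corresponding residual back on the client recovers both values exactly.

import torch

torch.manual_seed(0)
block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Tanh(), torch.nn.Linear(8, 8))

x = torch.randn(4, 8, requires_grad=True)   # client-side inputs
grad_out = torch.randn(4, 8)                # client-side grad w.r.t. outputs

y = x + block(x)  # toy stand-in for the requested_backends chain (the skip connection keeps y close to x)
(grad_in,) = torch.autograd.grad(y, x, grad_outputs=grad_out)

wire_fwd = (y - x).detach()    # what _rpc_forward now returns
wire_bwd = grad_in - grad_out  # what _rpc_backward now returns

# client side, as in run_remote_forward / run_remote_backward: add the residuals back
assert torch.allclose(x + wire_fwd, y)
assert torch.allclose(grad_out + wire_bwd, grad_in)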

+ 1 - 1
tests/test_block_exact_match.py

@@ -16,7 +16,7 @@ from src.dht_utils import get_remote_module
 
 
 @pytest.mark.forked
-def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+def test_remote_block_exact_match(atol_forward=3e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)