Răsfoiți Sursa

make all computations in handler

dbaranchuk 3 ani în urmă
părinte
comite
78d50a4a03
5 a modificat fișierele cu 57 adăugiri și 35 ștergeri
  1. 0 6
      src/bloom/block.py
  2. 16 3
      src/client/sequential_autograd.py
  3. 41 18
      src/server/handler.py
  4. 0 3
      src/server/server.py
  5. 0 5
      src/utils/misc.py

+ 0 - 6
src/bloom/block.py

@@ -18,7 +18,6 @@ from src.bloom.ops import (
     pre_process_alibi_for_pad,
     split_tensor_along_last_dim,
 )
-from src.utils.misc import is_dummy_batch
 
 
 class BloomAttention(nn.Module):
@@ -249,11 +248,6 @@ class BloomBlock(nn.Module):
         # MLP.
         output = self.mlp(layernorm_output, residual)
 
-        batch_size = hidden_states.shape[0]
-        if prompts is not None and not is_dummy_batch(prompts, batch_size):
-            pre_seq_len = prompts.shape[1]
-            output[:, :pre_seq_len] = output[:, :pre_seq_len] + prompts
-
         if use_cache:
             outputs = (output,) + outputs
         else:

+ 16 - 3
src/client/sequential_autograd.py

@@ -34,7 +34,13 @@ async def run_expert_forward(
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
     forward_inputs = (inputs, kwargs)
 
-    if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    if not nested_compare(forward_inputs, forward_schema_with_prompts):
         raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
 
     forward_inputs = nested_flatten(forward_inputs)
@@ -45,7 +51,7 @@ async def run_expert_forward(
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
         )
     )
 
@@ -69,7 +75,14 @@ async def run_expert_backward(
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
     inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu)))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    backward_schema = tuple(nested_flatten((forward_schema_with_prompts, rpc_info["outputs_schema"])))
 
     # Asynchronous serialization
     loop = asyncio.get_running_loop()

+ 41 - 18
src/server/handler.py

@@ -12,7 +12,7 @@ from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
-from src.utils.misc import DUMMY, is_dummy, is_dummy_batch, make_dummy_batch
+from src.utils.misc import DUMMY, is_dummy
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -128,11 +128,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
 
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )
+
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
             ]
         )
 
@@ -146,10 +153,17 @@ class TransformerConnectionHandler(ConnectionHandler):
 
         grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)
 
+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )
+
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
@@ -200,17 +214,20 @@ class TransformerConnectionHandler(ConnectionHandler):
 
 async def _rpc_forward(inputs, requested_backends):
     # Cast inputs to backend dtype
-    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
-    assert len(hidden_states) == 2 and hidden_states[0].ndim == 3
-    hidden_states, prompts = hidden_states
+    inputs = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
+    assert len(inputs) == 2 and inputs[0].ndim == 3
+    hidden_states, prompts = inputs
 
     if is_dummy(prompts):
-        batch_size = hidden_states.shape[0]
-        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        pre_seq_len = prompts.shape[2]
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
-        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, prompt)
+        if not is_dummy(prompt):
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
@@ -225,11 +242,11 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     inputs = inputs.to(requested_backends[0].dtype)
     prompts = prompts.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-    batch_size = inputs.shape[0]
 
     if is_dummy(prompts):
-        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
+        prompts = [DUMMY] * len(requested_backends)
     else:
+        pre_seq_len = prompts.shape[2]
         prompts = [p.squeeze(0) for p in prompts.split(1)]
 
     # Run a forward chain to collect intermediate inputs
@@ -237,19 +254,25 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     inter_inputs = [inputs]
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-        (inputs,) = await backend.forward_pool.submit_task(inputs, prompt)
+        if not is_dummy(prompt):
+            inputs[:, :pre_seq_len] += prompt
+        (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
         inter_inputs.append(inputs)
 
     grad_prompts = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
-        grads = await backend.backward_pool.submit_task(inp, prompt, grad_outputs)
-        assert isinstance(grads, (list, tuple)) and len(grads) == 2
-        grad_outputs, grad_prompt = grads
-        grad_prompts.append(grad_prompt[None])
-
-    is_dummy_grad_prompts = [is_dummy_batch(grad_param, batch_size) for grad_param in grad_prompts]
+        if not is_dummy(prompt):
+            inp[:, :pre_seq_len] += prompt
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+        else:
+            grad_prompts.append(DUMMY)
+
+    is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
     grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
     grads = [grad_outputs, grad_prompts]
     return grads

+ 0 - 3
src/server/server.py

@@ -212,9 +212,6 @@ class Server(threading.Thread):
                     BatchTensorDescriptor(
                         1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                     ),
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
                 ),
                 kwargs_schema={},
                 outputs_schema=(

+ 0 - 5
src/utils/misc.py

@@ -1,12 +1,7 @@
 import torch
 
 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
-make_dummy_batch = lambda x: torch.empty(x)
 
 
 def is_dummy(tensor: torch.Tensor):
     return tensor.numel() == 0
-
-
-def is_dummy_batch(tensor: torch.Tensor, batch_size: int):
-    return tensor.numel() == batch_size and tensor.ndim == 1