make all computations in handler

dbaranchuk 3 years ago
parent
commit
78d50a4a03
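
In short: prompt handling moves out of BloomBlock and into the server-side connection handler. The handler now adds each layer's prompt to the first pre_seq_len positions of the hidden states before submitting them to the backend pools, so the pools and blocks only ever receive a single hidden-states tensor. A minimal sketch of the new handler-side forward loop, condensed from _rpc_forward in src/server/handler.py below (the forward_chain name is just for illustration; asserts and dtype casts are omitted):

# Sketch only; DUMMY and is_dummy come from src/utils/misc.py, and forward_pool is the
# hivemind task pool owned by each TransformerBackend.
from src.utils.misc import DUMMY, is_dummy

async def forward_chain(hidden_states, prompts, requested_backends):
    if is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)  # no prompts for any layer
    else:
        pre_seq_len = prompts.shape[2]  # per-layer prompts stacked along dim 0
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            # prompt tuning: add this layer's prompt to the first pre_seq_len positions
            hidden_states[:, :pre_seq_len] += prompt
        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
    return hidden_states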

src/bloom/block.py (+0, -6)

@@ -18,7 +18,6 @@ from src.bloom.ops import (
     pre_process_alibi_for_pad,
     split_tensor_along_last_dim,
 )
-from src.utils.misc import is_dummy_batch


 class BloomAttention(nn.Module):
@@ -249,11 +248,6 @@ class BloomBlock(nn.Module):
         # MLP.
         output = self.mlp(layernorm_output, residual)

-        batch_size = hidden_states.shape[0]
-        if prompts is not None and not is_dummy_batch(prompts, batch_size):
-            pre_seq_len = prompts.shape[1]
-            output[:, :pre_seq_len] = output[:, :pre_seq_len] + prompts
-
         if use_cache:
             outputs = (output,) + outputs
         else:

src/client/sequential_autograd.py (+16, -3)

@@ -34,7 +34,13 @@ async def run_expert_forward(
     # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
     forward_inputs = (inputs, kwargs)

-    if not nested_compare(forward_inputs, rpc_info["forward_schema"]):
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    if not nested_compare(forward_inputs, forward_schema_with_prompts):
         raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")

     forward_inputs = nested_flatten(forward_inputs)
@@ -45,7 +51,7 @@ async def run_expert_forward(
     serialized_tensors = await asyncio.gather(
         *(
             loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
-            for tensor, proto in zip(inputs, nested_flatten(rpc_info["forward_schema"]))
+            for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
         )
     )

@@ -69,7 +75,14 @@ async def run_expert_backward(

     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
     inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu)))
-    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"])))
+
+    # Modify forward_schema to support prompts
+    args_schema, kwargs_schema = rpc_info["forward_schema"]
+    # TODO: rm this assert when support arbitrary number of input tensors
+    assert len(args_schema) == 1 and len(inputs) == 2
+    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+
+    backward_schema = tuple(nested_flatten((forward_schema_with_prompts, rpc_info["outputs_schema"])))

     # Asynchronous serialization
     loop = asyncio.get_running_loop()
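
For context on the schema change in this file: the server's rpc_info advertises a single positional tensor (the hidden states), and the client now sends two tensors (hidden states plus prompts), so the one descriptor is duplicated to cover both. A toy illustration, with placeholder strings standing in for the real hivemind tensor descriptors:

# Toy illustration only; real args_schema entries are hivemind tensor descriptors.
args_schema = ("descr_hidden_states",)   # the server advertises one positional tensor
kwargs_schema = {}
inputs = ("hidden_states", "prompts")    # the client now sends two tensors
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
assert forward_schema_with_prompts == (("descr_hidden_states", "descr_hidden_states"), {})
# i.e. prompts are serialized with the same dtype/compression settings as hidden states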

src/server/handler.py (+41, -18)

@@ -12,7 +12,7 @@ from hivemind.utils.streaming import split_for_streaming

 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
-from src.utils.misc import DUMMY, is_dummy, is_dummy_batch, make_dummy_batch
+from src.utils.misc import DUMMY, is_dummy


 class TransformerConnectionHandler(ConnectionHandler):
@@ -128,11 +128,18 @@ class TransformerConnectionHandler(ConnectionHandler):

         grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)

+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )
+
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
             tensors=[
                 serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
             ]
         )

@@ -146,10 +153,17 @@ class TransformerConnectionHandler(ConnectionHandler):

         grads = await _rpc_backward(inputs, prompts, grad_outputs, requested_backends)

+        # Modify grad_inputs_schema to support grad_prompts
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) == 2
+        grad_inputs_schema_with_prompts = (
+            requested_backends[0].args_schema * len(grads),
+            requested_backends[0].kwargs_schema,
+        )
+
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(grads, nested_flatten(requested_backends[0].grad_inputs_schema))
+            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
         ]
         # Split the serialized_grad_inputs for streaming and respond
         output_split = [
@@ -200,17 +214,20 @@ class TransformerConnectionHandler(ConnectionHandler):

 async def _rpc_forward(inputs, requested_backends):
     # Cast inputs to backend dtype
-    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
-    assert len(hidden_states) == 2 and hidden_states[0].ndim == 3
-    hidden_states, prompts = hidden_states
+    inputs = [tensor.to(requested_backends[0].dtype) for tensor in inputs]
+    assert len(inputs) == 2 and inputs[0].ndim == 3
+    hidden_states, prompts = inputs

     if is_dummy(prompts):
-        batch_size = hidden_states.shape[0]
-        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
+        prompts = [DUMMY] * len(requested_backends)
+    else:
+        pre_seq_len = prompts.shape[2]

     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
-        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, prompt)
+        if not is_dummy(prompt):
+            hidden_states[:, :pre_seq_len] += prompt
+        (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         assert isinstance(hidden_states, torch.Tensor)
         assert (
             hidden_states.ndim == 3
@@ -225,11 +242,11 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     inputs = inputs.to(requested_backends[0].dtype)
     prompts = prompts.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
-    batch_size = inputs.shape[0]

     if is_dummy(prompts):
-        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
+        prompts = [DUMMY] * len(requested_backends)
     else:
+        pre_seq_len = prompts.shape[2]
         prompts = [p.squeeze(0) for p in prompts.split(1)]

     # Run a forward chain to collect intermediate inputs
@@ -237,19 +254,25 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     inter_inputs = [inputs]
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-        (inputs,) = await backend.forward_pool.submit_task(inputs, prompt)
+        if not is_dummy(prompt):
+            inputs[:, :pre_seq_len] += prompt
+        (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
         inter_inputs.append(inputs)

     grad_prompts = []
     # Run a chain of requested backends
     for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
-        grads = await backend.backward_pool.submit_task(inp, prompt, grad_outputs)
-        assert isinstance(grads, (list, tuple)) and len(grads) == 2
-        grad_outputs, grad_prompt = grads
-        grad_prompts.append(grad_prompt[None])
-
-    is_dummy_grad_prompts = [is_dummy_batch(grad_param, batch_size) for grad_param in grad_prompts]
+        if not is_dummy(prompt):
+            inp[:, :pre_seq_len] += prompt
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
+        assert isinstance(grad_outputs, torch.Tensor)
+        if not is_dummy(prompt):
+            grad_prompts.append(grad_outputs[:, :pre_seq_len].unsqueeze(0))
+        else:
+            grad_prompts.append(DUMMY)
+
+    is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
     grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
     grads = [grad_outputs, grad_prompts]
     return grads
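
One step in _rpc_backward above deserves a note: since a prompt enters its layer only through an elementwise addition to the first pre_seq_len positions of the layer input, its gradient is exactly grad_outputs[:, :pre_seq_len] at that layer, so no separate backward call for the prompt is needed. A toy autograd check of that identity (shapes are made up for the example):

# Toy check: the gradient of an additive prompt equals the gradient of the positions it
# was added to. Mirrors the pattern removed from block.py
# (output[:, :pre_seq_len] = output[:, :pre_seq_len] + prompts); the handler uses the
# in-place "+=" form of the same addition.
import torch

hidden = torch.randn(2, 8, 4, requires_grad=True)   # [batch, seq_len, hidden_size]
prompt = torch.randn(2, 3, 4, requires_grad=True)   # pre_seq_len = 3
out = hidden.clone()
out[:, :3] = out[:, :3] + prompt
out.sum().backward()
assert torch.equal(prompt.grad, hidden.grad[:, :3])  # prompt grad == grad of those positions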

src/server/server.py (+0, -3)

@@ -212,9 +212,6 @@ class Server(threading.Thread):
                     BatchTensorDescriptor(
                         1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                     ),
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
                 ),
                 kwargs_schema={},
                 outputs_schema=(

src/utils/misc.py (+0, -5)

@@ -1,12 +1,7 @@
 import torch

 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
-make_dummy_batch = lambda x: torch.empty(x)


 def is_dummy(tensor: torch.Tensor):
     return tensor.numel() == 0
-
-
-def is_dummy_batch(tensor: torch.Tensor, batch_size: int):
-    return tensor.numel() == batch_size and tensor.ndim == 1
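
With make_dummy_batch and is_dummy_batch removed, an empty tensor is the only sentinel left for "no prompt", and the handler checks it per layer with is_dummy instead of comparing numel against the batch size. For example:

# The remaining sentinel: any tensor with zero elements means "nothing was passed here".
import torch

DUMMY = torch.empty(0)

def is_dummy(tensor: torch.Tensor) -> bool:
    return tensor.numel() == 0

assert is_dummy(DUMMY)
assert not is_dummy(torch.zeros(2, 16, 4096))  # a real prompt tensor (arbitrary shape) is never empty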