Commit 16ab6fba58 by dbaranchuk, 3 years ago
5 changed files with 39 additions and 37 deletions
  1. src/bloom/block.py (+4, -3)
  2. src/client/remote_model.py (+15, -22)
  3. src/client/sequential_autograd.py (+5, -3)
  4. src/server/handler.py (+10, -9)
  5. src/utils/misc.py (+5, -0)

src/bloom/block.py (+4, -3)

@@ -18,7 +18,7 @@ from src.bloom.ops import (
     pre_process_alibi_for_pad,
     split_tensor_along_last_dim,
 )
-from src.utils.misc import DUMMY, is_dummy
+from src.utils.misc import is_dummy_batch
 
 
 class BloomAttention(nn.Module):
@@ -203,7 +203,7 @@ class BloomBlock(nn.Module):
     def forward(
         self,
         hidden_states,
-        prompts=DUMMY,
+        prompts=None,
         layer_past=None,
         attention_mask=None,
         head_mask=None,
@@ -249,7 +249,8 @@ class BloomBlock(nn.Module):
         # MLP.
         output = self.mlp(layernorm_output, residual)
 
-        if not is_dummy(prompts):
+        batch_size = hidden_states.shape[0]
+        if prompts is not None and not is_dummy_batch(prompts, batch_size):
             pre_seq_len = prompts.shape[1]
             output[:, :pre_seq_len] = output[:, :pre_seq_len] + prompts
 

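For context, the new guard in `BloomBlock.forward` skips prompt injection when it receives a per-request dummy placeholder instead of real prompts. The sketch below re-implements the helpers from `src/utils/misc.py` locally and uses made-up tensor sizes; it is illustrative, not part of the commit.

```python
from typing import Optional

import torch


def make_dummy_batch(batch_size: int) -> torch.Tensor:
    # 1-D placeholder whose numel equals the batch size (mirrors src/utils/misc.py)
    return torch.empty(batch_size)


def is_dummy_batch(tensor: torch.Tensor, batch_size: int) -> bool:
    return tensor.numel() == batch_size and tensor.ndim == 1


def add_prompts(output: torch.Tensor, prompts: Optional[torch.Tensor]) -> torch.Tensor:
    # Mirrors the tail of BloomBlock.forward after this commit
    batch_size = output.shape[0]
    if prompts is not None and not is_dummy_batch(prompts, batch_size):
        pre_seq_len = prompts.shape[1]
        output[:, :pre_seq_len] = output[:, :pre_seq_len] + prompts
    return output


hidden = torch.zeros(2, 8, 16)  # (batch, seq_len, hidden_size), sizes are arbitrary
assert add_prompts(hidden.clone(), make_dummy_batch(2)).equal(hidden)       # dummy -> no-op
assert not add_prompts(hidden.clone(), torch.ones(2, 4, 16)).equal(hidden)  # real prompts are added
```
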
src/client/remote_model.py (+15, -22)

@@ -1,5 +1,5 @@
 # this code is in active development, interfaces may change
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import hivemind
 import torch
@@ -34,9 +34,7 @@ class DistributedBloomConfig(BloomConfig):
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
-    tuning_mode: Optional[
-        str
-    ] = None  # One of the available options for fine-tuning: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
 
 
 class DistributedBloomModel(BloomModel):
@@ -64,8 +62,7 @@ class DistributedBloomModel(BloomModel):
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
 
-        self.tuning_mode = config.tuning_mode
-        if self.tuning_mode and "ptune" in config.tuning_mode:
+        if config.tuning_mode and "ptune" in config.tuning_mode:
             assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
             self.pre_seq_len = config.pre_seq_len
             self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
@@ -73,11 +70,13 @@ class DistributedBloomModel(BloomModel):
 
             if config.tuning_mode == "deep_ptune":
                 self.intermediate_prompt_embeddings = nn.Embedding(
-                    self.pre_seq_len, (config.num_hidden_layers - 1) * config.hidden_size
+                    self.pre_seq_len,
+                    config.num_hidden_layers * config.hidden_size
+                    # ^-- TODO: should be num_hidden_layers - 1
                 )
                 self.intermediate_prompt_embeddings.weight.data.zero_()
-        elif self.tuning_mode:
-            raise NotImplementedError(f"TODO: {self.tuning_mode}")
+        elif config.tuning_mode:
+            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
 
     def set_requires_grad(self, value):
         for p in self.parameters():
@@ -88,10 +87,10 @@ class DistributedBloomModel(BloomModel):
         prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
         prompts = self.prompt_embeddings(prefix_tokens)
 
-        if self.tuning_mode == "deep_ptune":
+        if self.config.tuning_mode == "deep_ptune":
             intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
             intermediate_prompts = intermediate_prompts.view(
-                batch_size, self.pre_seq_len, len(self.h) - 1, self.hidden_size
+                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
             )
             intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
         else:
@@ -124,7 +123,7 @@ class DistributedBloomModel(BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.tuning_mode and "ptune" in self.tuning_mode:
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -132,13 +131,13 @@ class DistributedBloomModel(BloomModel):
         hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if "ptune" in self.tuning_mode:
+        if "ptune" in self.config.tuning_mode:
             hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
         else:
             hidden_states = self.h(hidden_states)
 
         # Remove prefix
-        if self.tuning_mode and "ptune" in self.tuning_mode:
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state
@@ -159,10 +158,7 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
 
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
         self.lm_head = LMHead(config, self.transformer.word_embeddings)
 
         # Initialize weights and apply final processing
@@ -192,10 +188,7 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
 
     def __init__(self, config: DistributedBloomConfig):
         super().__init__(config)
-        if config.pre_seq_len > 0:
-            self.transformer = DistributedBloomPrefix(config)
-        else:
-            self.transformer = DistributedBloomModel(config)
+        self.transformer = DistributedBloomModel(config)
 
         # Initialize weights and apply final processing
         self.post_init()

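The deep_ptune branch above reshapes one flat embedding row per prefix token into per-layer prompts. Below is a minimal shape check with arbitrary sizes standing in for `DistributedBloomConfig` values; it keeps the commit's TODO behavior of allocating `num_hidden_layers` slots rather than `num_hidden_layers - 1`.

```python
import torch

# Arbitrary stand-ins for config.pre_seq_len, config.num_hidden_layers, config.hidden_size
batch_size, pre_seq_len, num_hidden_layers, hidden_size = 2, 4, 6, 8

intermediate_prompt_embeddings = torch.nn.Embedding(pre_seq_len, num_hidden_layers * hidden_size)
prefix_tokens = torch.arange(pre_seq_len).unsqueeze(0).expand(batch_size, -1)

intermediate_prompts = intermediate_prompt_embeddings(prefix_tokens)  # (batch, pre_seq_len, layers * hidden)
intermediate_prompts = intermediate_prompts.view(
    batch_size, pre_seq_len, num_hidden_layers, hidden_size
).permute([2, 0, 1, 3])                                               # (layers, batch, pre_seq_len, hidden)

assert intermediate_prompts.shape == (num_hidden_layers, batch_size, pre_seq_len, hidden_size)
```
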
src/client/sequential_autograd.py (+5, -3)

@@ -101,8 +101,9 @@ async def sequential_forward(
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
-    if not is_dummy(prompts):
-        assert len(prompts) == end_index - start_index + 1
+    assert is_dummy(prompts) or len(prompts) == len(
+        sequence_manager.block_uids
+    )  # should be n_layers - 1 but add extra prompts for convenience
 
     sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
@@ -163,6 +164,7 @@ async def sequential_backward(
                 grad_outputs, span_grad_prompts = await run_expert_backward(
                     span_uids, stub, sequence_manager.rpc_info, inputs_and_prompts, grad_outputs
                 )
+                grad_outputs = [grad_outputs]
                 grad_prompts.append(span_grad_prompts)
                 break
             except Exception as e:
@@ -256,7 +258,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
                 ctx.sequence_manager,
             )
         )
-        grad_input_batches = [output[0] for output in outputs]
+        grad_input_batches = [output[0][0] for output in outputs]
         grad_prompt_batches = [output[1] for output in outputs]
 
         grad_inputs = torch.cat(grad_input_batches, dim=0)

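To see why `_RemoteSequentialAutogradFunction.backward` now unpacks `output[0][0]`, note that after this change the input gradient stays wrapped in a single-element list. A toy, synchronous stand-in for the real async, span-based code (names here are simplified and hypothetical):

```python
import torch


def toy_sequential_backward(grad_output: torch.Tensor):
    # Stand-in for run_expert_backward: the input gradient is kept inside a list,
    # and span prompt gradients are returned alongside it.
    grad_outputs = [grad_output * 2.0]
    span_grad_prompts = torch.zeros(1)
    return grad_outputs, span_grad_prompts


outputs = [toy_sequential_backward(torch.ones(3))]         # one entry per input batch chunk
grad_input_batches = [output[0][0] for output in outputs]  # unwrap list, then tensor, as in the diff above
grad_inputs = torch.cat(grad_input_batches, dim=0)
assert grad_inputs.shape == (3,)
```
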
src/server/handler.py (+10, -9)

@@ -12,7 +12,7 @@ from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
 from src.server.backend import MAX_LENGTH, TransformerBackend
-from src.utils.misc import DUMMY, is_dummy
+from src.utils.misc import DUMMY, is_dummy, is_dummy_batch, make_dummy_batch
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -205,7 +205,8 @@ async def _rpc_forward(inputs, requested_backends):
     hidden_states, prompts = hidden_states
 
     if is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
+        batch_size = hidden_states.shape[0]
+        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
 
     # Run a chain of requested backends
     for backend, prompt in zip(requested_backends, prompts):
@@ -224,29 +225,29 @@ async def _rpc_backward(inputs, prompts, grad_outputs, requested_backends):
     inputs = inputs.to(requested_backends[0].dtype)
     prompts = prompts.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
+    batch_size = inputs.shape[0]
 
     if is_dummy(prompts):
-        prompts = [DUMMY] * len(requested_backends)
+        prompts = [make_dummy_batch(batch_size)] * len(requested_backends)
 
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
     inter_inputs = [inputs]
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
-        inputs = await backend.forward_pool.submit_task(inputs, prompt)
-        assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
-        inputs = inputs[0]
+        (inputs,) = await backend.forward_pool.submit_task(inputs, prompt)
+        assert isinstance(inputs, torch.Tensor)
         inter_inputs.append(inputs)
 
     grad_prompts = []
     # Run a chain of requested backends
-    for inp, prompt, backend in zip(inter_inputs[::-1], prompts[::-1], requested_backends[::-1]):
+    for inp, prompt, backend in zip(inter_inputs[::-1], prompts.flip(0), requested_backends[::-1]):
         grads = await backend.backward_pool.submit_task(inp, prompt, grad_outputs)
         assert isinstance(grads, (list, tuple)) and len(grads) == 2
         grad_outputs, grad_prompt = grads
-        grad_prompts.append(grad_prompt)
+        grad_prompts.append(grad_prompt[None])
 
-    is_dummy_grad_prompts = [is_dummy(grad_param) for grad_param in grad_prompts]
+    is_dummy_grad_prompts = [is_dummy_batch(grad_param, batch_size) for grad_param in grad_prompts]
     grad_prompts = torch.cat(grad_prompts, dim=0) if not any(is_dummy_grad_prompts) else DUMMY
     grads = [grad_outputs, grad_prompts]
     return grads

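In `_rpc_backward`, indexing with `grad_prompt[None]` prepends a layer dimension so the per-backend prompt gradients can be concatenated along dim 0 (or replaced by `DUMMY` if every backend returned a dummy). A sketch with made-up shapes:

```python
import torch

batch_size, pre_seq_len, hidden_size = 2, 4, 8

grad_prompts = []
for _ in range(3):                                  # one grad_prompt per requested backend
    grad_prompt = torch.zeros(batch_size, pre_seq_len, hidden_size)
    grad_prompts.append(grad_prompt[None])          # -> (1, batch, pre_seq_len, hidden)

stacked = torch.cat(grad_prompts, dim=0)            # -> (num_backends, batch, pre_seq_len, hidden)
assert stacked.shape == (3, batch_size, pre_seq_len, hidden_size)
```
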
src/utils/misc.py (+5, -0)

@@ -1,7 +1,12 @@
 import torch
 
 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
+make_dummy_batch = lambda x: torch.empty(x)
 
 
 def is_dummy(tensor: torch.Tensor):
     return tensor.numel() == 0
+
+
+def is_dummy_batch(tensor: torch.Tensor, batch_size: int):
+    return tensor.numel() == batch_size and tensor.ndim == 1
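Unlike the zero-element `DUMMY`, a dummy batch encodes the batch size it was made for, so a placeholder can be told apart from real prompts per request. A quick usage check (assuming the repository's `src` package is importable):

```python
import torch

from src.utils.misc import DUMMY, is_dummy, is_dummy_batch, make_dummy_batch

batch_size = 4
placeholder = make_dummy_batch(batch_size)  # 1-D tensor with batch_size uninitialized values

assert is_dummy(DUMMY) and not is_dummy(placeholder)
assert is_dummy_batch(placeholder, batch_size)
assert not is_dummy_batch(placeholder, batch_size + 1)              # wrong batch size
assert not is_dummy_batch(torch.zeros(batch_size, 16), batch_size)  # real 2-D prompts
```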