소스 검색

Return multibatch mode

Artem Chumachenko 3 년 전
부모
커밋
ade986ca58
1개의 변경된 파일49개의 추가작업 그리고 0개의 파일을 삭제
  1. 49 0
      src/client/remote_model.py

+ 49 - 0
src/client/remote_model.py

@@ -151,6 +151,55 @@ class DistributedBloomModel(BloomModel):
         )
 
 
+class DistributedBloomPrefix(DistributedBloomModel):
+    """DistributedBloomModel with prefix tokens for prompt tuning"""
+
+    def __init__(self, config):
+        super().__init__(config)
+        assert config.num_prefix_tokens > 0, "The number of prefix tokens must be > 0"
+        self.prefix_length = config.num_prefix_tokens
+
+        self.prompt_embeddings = nn.Embedding(self.prefix_length, config.hidden_size)
+        self.prefix_tokens = torch.arange(self.prefix_length).long()
+
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+        return prompts
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        assert (
+            input_ids is None or inputs_embeds is None
+        ), "You cannot specify both input_ids and inputs_embeds at the same time"
+        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        batch_size = inputs_embeds.shape[0]
+
+        if attention_mask is not None:
+            prefix_attention_mask = torch.ones(batch_size, self.prefix_length, device=attention_mask.device)
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+        prompts = self.get_prompt(batch_size)
+        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
+
+        # Remove prefix
+        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
+        transformer_outputs["last_hidden_state"] = last_hidden_state
+        return transformer_outputs
+
+
 class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""