@@ -1,13 +1,12 @@
 # this code is in active development, interfaces may change
-import os
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 import hivemind
 import torch
 import torch.nn as nn
 from hivemind import get_logger, use_hivemind_log_handler
-
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+
 from src.bloom.model import (
     BloomConfig,
     BloomForCausalLM,
@@ -68,12 +67,18 @@ class DistributedBloomModel(BloomModel):
             p.requires_grad = value

     def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs
+        **kwargs,
     ):
+        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
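The checks added to DistributedBloomModel.forward fail fast on unsupported inputs instead of silently ignoring them: attention masks are rejected outright, while any other non-trivial keyword argument is reported at debug level. A minimal standalone sketch of that behavior follows; the helper name _check_unsupported_args is hypothetical and introduced here only for illustration, and it uses the stdlib logger rather than hivemind's get_logger:

import logging

logger = logging.getLogger(__name__)

def _check_unsupported_args(attention_mask=None, **kwargs):
    # Mirrors the assertion added in the diff: masks are not supported yet.
    assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
    # Extra kwargs that are effectively unset (None or False) pass silently;
    # anything else is surfaced at debug level instead of raising.
    for k, v in kwargs.items():
        if not (v is None or v is False):
            logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

_check_unsupported_args(use_cache=False)          # silent: False counts as unset
_check_unsupported_args(output_attentions=True)   # logged at debug level
# _check_unsupported_args(attention_mask=mask)    # would raise AssertionError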
@@ -124,7 +129,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs
+        **kwargs,
     ):
         assert (
             input_ids is None or inputs_embeds is None
@@ -143,14 +148,10 @@ class DistributedBloomPrefix(DistributedBloomModel):
         prompts = self.get_prompt(batch_size)
         inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

-        transformer_outputs = super().forward(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **kwargs
-        )
+        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)

         # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
+        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
         transformer_outputs["last_hidden_state"] = last_hidden_state
         return transformer_outputs
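For context, the DistributedBloomPrefix hunk above only reformats the existing prompt-tuning flow: learned prompts are prepended along the sequence dimension before the transformer call, and the same number of positions is sliced off the output so callers only see hidden states for their real tokens. A minimal sketch of that shape arithmetic with toy tensors (all sizes here are illustrative, not taken from the source):

import torch

batch_size, prefix_length, seq_len, hidden_size = 2, 4, 7, 16  # toy sizes

prompts = torch.randn(batch_size, prefix_length, hidden_size)  # ~ self.get_prompt(batch_size)
inputs_embeds = torch.randn(batch_size, seq_len, hidden_size)  # real token embeddings

# Prepend the learned prompts along the sequence dimension (dim=1)
extended_embeds = torch.cat([prompts, inputs_embeds], dim=1)
assert extended_embeds.shape == (batch_size, prefix_length + seq_len, hidden_size)

# The transformer keeps the sequence length, so slicing off the first
# prefix_length positions restores the caller's view of the sequence.
last_hidden_state = extended_embeds[:, prefix_length:]
assert last_hidden_state.shape == (batch_size, seq_len, hidden_size)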