
address comments & black & isort

dbaranchuk, 3 years ago
parent
commit 3dfa95636e
1 changed file with 14 additions and 13 deletions

src/client/remote_model.py  +14 −13

@@ -1,13 +1,12 @@
 # this code is in active development, interfaces may change
-import os
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 
 import hivemind
 import torch
 import torch.nn as nn
 from hivemind import get_logger, use_hivemind_log_handler
-
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+
 from src.bloom.model import (
     BloomConfig,
     BloomForCausalLM,
@@ -68,12 +67,18 @@ class DistributedBloomModel(BloomModel):
             p.requires_grad = value
 
     def forward(
-        self, 
-        input_ids: Optional[torch.LongTensor] = None, 
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs
+        **kwargs,
     ):
+        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -124,7 +129,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs
+        **kwargs,
     ):
         assert (
             input_ids is None or inputs_embeds is None
@@ -143,14 +148,10 @@ class DistributedBloomPrefix(DistributedBloomModel):
         prompts = self.get_prompt(batch_size)
         inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
-        transformer_outputs = super().forward(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            **kwargs
-        )
+        transformer_outputs = super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
 
         # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
+        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
         transformer_outputs["last_hidden_state"] = last_hidden_state
         return transformer_outputs