Browse Source

fix imports

justheuristic 3 years ago
parent
commit
54ba92dbe3
2 changed files with 4 additions and 3 deletions
  1. 1 1
      requirements.txt
  2. 3 2
      src/bloom/model.py

+ 1 - 1
requirements.txt

@@ -3,4 +3,4 @@ accelerate==0.10.0
 huggingface-hub==0.7.0
 bitsandbytes-cuda113==0.26.0
 https://github.com/learning-at-home/hivemind/archive/d42c70331da43667da6d9020666df54806d8b561.zip
-https://github.com/huggingface/transformers/archive/ccc089780415445768bcfd3ac4418cec20353484.zip
+https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

+ 3 - 2
src/bloom/model.py

@@ -10,13 +10,14 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
 from transformers.file_utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
 )
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, \
+    SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.utils import logging