|
@@ -10,13 +10,14 @@ import torch.nn.functional as F
|
|
|
import torch.utils.checkpoint
|
|
|
from hivemind import use_hivemind_log_handler
|
|
|
from torch import nn
|
|
|
-from torch.nn import CrossEntropyLoss, LayerNorm
|
|
|
+from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
|
|
from transformers.file_utils import (
|
|
|
add_code_sample_docstrings,
|
|
|
add_start_docstrings,
|
|
|
add_start_docstrings_to_model_forward,
|
|
|
)
|
|
|
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
|
|
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, \
|
|
|
+ SequenceClassifierOutputWithPast
|
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
|
from transformers.models.bloom.configuration_bloom import BloomConfig
|
|
|
from transformers.utils import logging
|