|
@@ -10,14 +10,17 @@ import torch.nn.functional as F
|
|
|
import torch.utils.checkpoint
|
|
|
from hivemind import use_hivemind_log_handler
|
|
|
from torch import nn
|
|
|
-from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
|
|
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
|
|
|
from transformers.file_utils import (
|
|
|
add_code_sample_docstrings,
|
|
|
add_start_docstrings,
|
|
|
add_start_docstrings_to_model_forward,
|
|
|
)
|
|
|
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, \
|
|
|
- SequenceClassifierOutputWithPast
|
|
|
+from transformers.modeling_outputs import (
|
|
|
+ BaseModelOutputWithPastAndCrossAttentions,
|
|
|
+ CausalLMOutputWithCrossAttentions,
|
|
|
+ SequenceClassifierOutputWithPast,
|
|
|
+)
|
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
|
from transformers.models.bloom.configuration_bloom import BloomConfig
|
|
|
from transformers.utils import logging
|