@@ -23,17 +23,11 @@ import torch.nn.functional as F
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.utils.checkpoint import checkpoint, get_device_states, set_device_states
 from transformers import AlbertConfig
-from transformers.file_utils import add_start_docstrings
 from transformers.modeling_outputs import BaseModelOutput
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.albert.modeling_albert import (
     ACT2FN,
-    ALBERT_START_DOCSTRING,
-    AlbertForPreTraining,
     AlbertLayerGroup,
     AlbertMLMHead,
-    AlbertModel,
-    AlbertSOPHead,
     AlbertTransformer,
 )
 from transformers.utils import logging
@@ -46,7 +40,7 @@ _TOKENIZER_FOR_DOC = "AlbertTokenizer"

 class LeanAlbertConfig(AlbertConfig):
     rotary_embedding_base: int = 10_000
-    hidden_act_gated: bool = True
+    hidden_act_gated: bool = False

     def __hash__(self):
         return hash("\t".join(f"{k}={v}" for k, v in self.__dict__.items() if not k.startswith("_")))
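
Note: a minimal sketch of how the changed default surfaces once the config is instantiated; the import path `lean_albert` below is a placeholder for the module being patched (not shown in this diff), and the field names are taken from the class above.

    from lean_albert import LeanAlbertConfig  # placeholder import path, not part of this diff

    config = LeanAlbertConfig()
    assert config.rotary_embedding_base == 10_000  # unchanged default
    assert config.hidden_act_gated is False        # default flipped from True by this change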