|
@@ -17,7 +17,6 @@ from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArgume
|
|
|
from data import make_dataset
|
|
|
from huggingface_auth import authorize_with_huggingface
|
|
|
from lib.training.clipped_lamb import LambWithGradientClipping
|
|
|
-from lib.training.offload import OffloadOptimizer
|
|
|
|
|
|
|
|
|
logger = hivemind.get_logger(__name__)
|
|
@@ -142,9 +141,8 @@ class TrainingTask:
|
|
|
},
|
|
|
]
|
|
|
|
|
|
- opt = OffloadOptimizer(
|
|
|
+ opt = LambWithGradientClipping(
|
|
|
optimizer_grouped_parameters,
|
|
|
- optim_cls=LambWithGradientClipping,
|
|
|
lr=training_args.learning_rate,
|
|
|
betas=(training_args.adam_beta1, training_args.adam_beta2),
|
|
|
eps=training_args.adam_epsilon,
|