Browse Source

Upgrade to using hivemind.optim.experimental

Aleksandr Borzunov 3 years ago
parent
commit
64dee420da
8 changed files with 296 additions and 81 deletions
  1. arguments.py (+2 -31)
  2. callback.py (+16 -27)
  3. lib/training/hf_trainer.py (+1 -1)
  4. lib/training/lamb_8bit.py (+250 -0)
  5. requirements.txt (+2 -1)
  6. run_aux_peer.py (+8 -7)
  7. run_trainer_tpu.py (+1 -1)
  8. task.py (+16 -13)

+ 2 - 31
arguments.py

@@ -64,45 +64,16 @@ class CollaborativeArguments:
         default=16384,
         metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
     )
-    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
-    bandwidth: float = field(
-        default=100.0,
-        metadata={"help": "Available network bandwidth, in mbps (used for load balancing in all-reduce)"},
-    )
-    averaging_expiration: float = field(
+    matchmaking_time: float = field(
         default=15.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
     )
     averaging_timeout: float = field(
-        default=300, metadata={"help": "Give up on averaging step after this many seconds"}
-    )
-    min_refresh_period: float = field(
-        default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
-    )
-    max_refresh_period: float = field(
-        default=30, metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
-    )
-    default_refresh_period: float = field(
-        default=3, metadata={"help": "Attempt to fetch collaboration state every this often until successful"}
-    )
-    expected_drift_peers: float = field(
-        default=3, metadata={"help": "Trainer assumes that this many new peers can join per step"}
-    )
-    expected_drift_rate: float = field(
-        default=0.2, metadata={"help": "Trainer assumes that this fraction of current size can join per step"}
-    )
-    performance_ema_alpha: float = field(
-        default=0.1, metadata={"help": "Uses this alpha for moving average estimate of samples per second"}
-    )
-    metadata_expiration: float = field(
-        default=120, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
+        default=120, metadata={"help": "Give up on averaging step after this many seconds"}
     )
     reuse_grad_buffers: bool = field(default=True, metadata={
         "help": "Whether or not to use model's .grad buffers for accumulating gradients across local steps. This "
                 "optimization reduces GPU memory consumption but may result in incorrect gradients when using some "
                 "advanced techniques (e.g. applying custom loss scaler)"})
-    request_timeout: float = field(
-        default=10, metadata={"help": "Timeout for averager requests (loading state, joining groups)"},
-    )
 
 
 @dataclass
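
In short, the old scheduling and drift-estimation knobs are removed because `hivemind.Optimizer` now handles them internally; only the matchmaking and timeout settings survive. A minimal sketch (not part of the commit; defaults are illustrative) of how the slimmed-down dataclass splats into the optimizer:

```python
from dataclasses import asdict, dataclass

# Sketch only: field names follow the diff above.
@dataclass
class CollabArgsSketch:
    target_batch_size: int = 16384     # step after this many samples swarm-wide
    matchmaking_time: float = 15.0     # replaces the old averaging_expiration
    averaging_timeout: float = 120.0   # give up on an averaging round after this
    reuse_grad_buffers: bool = True    # accumulate gradients directly in .grad

# hivemind.Optimizer accepts these keywords directly (see task.py below), so the
# removed refresh/drift parameters no longer need to be threaded through:
optimizer_kwargs = asdict(CollabArgsSketch())
```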

+ 16 - 27
callback.py

@@ -57,18 +57,18 @@ class CollaborativeCallback(transformers.TrainerCallback):
         if state.log_history:
             self.loss += state.log_history[-1]["loss"]
             self.steps += 1
-            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
-                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
+            if self.collaborative_optimizer.local_epoch != self.last_reported_collaboration_step:
+                self.last_reported_collaboration_step = self.collaborative_optimizer.local_epoch
                 self.total_samples_processed += self.samples
-                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
+                samples_per_second = self.collaborative_optimizer.tracker.performance_ema.samples_per_second
                 statistics = LocalMetrics(
-                    step=self.collaborative_optimizer.local_step,
+                    step=self.collaborative_optimizer.local_epoch,
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
                     loss=self.loss,
                     mini_steps=self.steps,
                 )
-                logger.info(f"Step {self.collaborative_optimizer.local_step}")
+                logger.info(f"Step {self.collaborative_optimizer.local_epoch}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 logger.info(f"Performance: {samples_per_second} samples per second.")
                 if self.steps:
@@ -76,19 +76,19 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
                 self.loss = 0
                 self.steps = 0
-                if self.collaborative_optimizer.is_synchronized:
+                if self.collaborative_optimizer.local_epoch == self.collaborative_optimizer.tracker.global_epoch:
                     self.dht.store(
-                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        key=self.collaborative_optimizer.run_id + "_metrics",
                         subkey=self.task.local_public_key,
                         value=statistics.dict(),
                         expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                         return_future=True,
                     )
                 if self.backup_every_steps is not None and \
-                        self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
+                        self.collaborative_optimizer.local_epoch % self.backup_every_steps == 0:
                     self.backup_state()
 
-        self.samples = self.collaborative_optimizer.local_samples_accumulated
+        self.samples = self.collaborative_optimizer.grad_averager.local_samples_accumulated
 
         return control
 
@@ -106,8 +106,8 @@ class CollaborativeCallback(transformers.TrainerCallback):
             {
                 "model": self.task.model.state_dict(),
                 "training": self.collaborative_optimizer.state_dict(),
-                "scheduler": self.collaborative_optimizer.scheduler.state_dict(),
-                "local_step": self.collaborative_optimizer.local_step,
+                "scheduler": self.collaborative_optimizer.state_averager.scheduler.state_dict(),
+                "local_epoch": self.collaborative_optimizer.local_epoch,
             },
             self.state_path,
         )
@@ -115,24 +115,13 @@ class CollaborativeCallback(transformers.TrainerCallback):
     @torch.no_grad()
     def restore_from_backup(self, path, check_step=False):
         state = torch.load(path)
-        current_step = self.collaborative_optimizer.local_step
-        backup_step = state['training']['state'][0]['step'] #TODO FIX THIS, use state['local_step']
+        current_step = self.collaborative_optimizer.local_epoch
+        backup_step = state['local_epoch']
         if not check_step or backup_step >= current_step:
-            if (
-                "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention_core.rotary_emb.cos"
-                in state["model"]
-            ):
-                del state["model"][
-                    "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention_core.rotary_emb.cos"
-                ]
-                del state["model"][
-                    "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention_core.rotary_emb.sin"
-                ]
-            if "scheduler" in state:
-                self.collaborative_optimizer.scheduler.load_state_dict(state["scheduler"])
-            self.collaborative_optimizer.load_state_dict(state["training"])
-            self.collaborative_optimizer.averager.local_step = backup_step
             self.task.model.load_state_dict(state["model"], strict=False)
+            self.collaborative_optimizer.load_state_dict(state["training"])
+            self.collaborative_optimizer.state_averager.scheduler.load_state_dict(state["scheduler"])
+            self.collaborative_optimizer.state_averager.local_epoch = backup_step
             logger.info("Restored from a backup")
         else:
             logger.info("Bypassed restoring state from local backup: backup state is too old.")

+ 1 - 1
lib/training/hf_trainer.py

@@ -35,7 +35,7 @@ class CollaborativeHFTrainer(Trainer):
     def _wrap_model(self, model, training=True):
         # if reuse_grad_buffers is True, we should accumulate gradients in .grad without zeroing them after each step
         return IgnoreGradManipulations(super()._wrap_model(model, training=training),
-                                       override_zero_grad=self.collaborative_optimizer.reuse_grad_buffers)
+                                       override_zero_grad=self.collaborative_optimizer.grad_averager.reuse_grad_buffers)
 
 
 class NoOpScheduler(LRSchedulerBase):
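
For context: with `reuse_grad_buffers`, hivemind accumulates gradients directly in each parameter's `.grad`, so the HF Trainer's per-step `zero_grad()` must be suppressed. A minimal sketch of the idea behind `IgnoreGradManipulations` (an assumption about its behavior, not its actual implementation):

```python
import torch

class SuppressZeroGradSketch(torch.nn.Module):
    """Wraps a model so that zero_grad() becomes a no-op when the
    collaborative optimizer owns the .grad buffers."""

    def __init__(self, module: torch.nn.Module, override_zero_grad: bool):
        super().__init__()
        self.module = module
        self.override_zero_grad = override_zero_grad

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def zero_grad(self, set_to_none: bool = True) -> None:
        if not self.override_zero_grad:
            super().zero_grad(set_to_none=set_to_none)
        # otherwise: keep accumulated gradients intact between local steps
```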

+ 250 - 0
lib/training/lamb_8bit.py

@@ -0,0 +1,250 @@
+import math
+from typing import Dict, Any, Optional
+
+import torch
+
+from torch_optimizer.types import Betas2, Params
+from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+from bitsandbytes.optim.optimizer import Optimizer2State
+
+__all__ = ('CPULAMB8Bit',)
+
+
+class CPULAMB8Bit(Optimizer2State):
+    r"""
+    Implements Lamb with quantized 8-bit statistics. The statistics are stored in host memory in the quantized form.
+    The LAMB optimizer and block-wise quantization are described in the following papers:
+    - LAMB: "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" https://arxiv.org/abs/1904.00962
+    - Quantization: "8-bit Optimizers via Block-wise Quantization" https://arxiv.org/abs/2110.02861
+    This specific implementation of LAMB is based on https://github.com/cybertronai/pytorch-lamb
+    - bias correction defaults to False because paper v3 does not use debiasing
+    - it has baked in clipping by global max_grad_norm
+    Arguments:
+        params: iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr: learning rate (default: 1e-3)
+        betas: coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps: term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay: weight decay (L2 penalty) (default: 0)
+        clamp_value: clamp weight_norm in (0, clamp_value) (default: 10)
+            set to a high value to avoid it (e.g. 10e3)
+        bias_correction: debias statistics by (1 - beta**step) (default: False)
+        min_8bit_size: statistics for parameters with fewer than this many elements will not be quantized
+        reuse_grad_buffers: if True, optimizer will modify gradients in-place to save memory.
+            If enabled, one must ensure that .zero_grad() is called after each optimizer step.
+        update_chunk_size: quantized statistics will be de-quantized in chunks of up to this many elements.
+    """
+
+    def __init__(
+        self,
+        params: Params,
+        lr: float = 1e-3,
+        betas: Betas2 = (0.9, 0.999),
+        eps: float = 1e-6,
+        weight_decay: float = 0,
+        clamp_value: float = 10,
+        bias_correction: bool = False,
+        min_8bit_size: int = 65536,
+        reuse_grad_buffers: bool = False,
+        update_chunk_size: int = 2 ** 24,
+        max_grad_norm: Optional[float] = None,
+    ) -> None:
+        if lr <= 0.0:
+            raise ValueError('Invalid learning rate: {}'.format(lr))
+        if eps < 0.0:
+            raise ValueError('Invalid epsilon value: {}'.format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(
+                'Invalid beta parameter at index 0: {}'.format(betas[0])
+            )
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(
+                'Invalid beta parameter at index 1: {}'.format(betas[1])
+            )
+        if weight_decay < 0:
+            raise ValueError(
+                'Invalid weight_decay value: {}'.format(weight_decay)
+            )
+        if clamp_value < 0.0:
+            raise ValueError('Invalid clamp value: {}'.format(clamp_value))
+
+        self.clamp_value = clamp_value
+        self.bias_correction = bias_correction
+        self.reuse_grad_buffers = reuse_grad_buffers
+        self.update_chunk_size = update_chunk_size
+        self.max_grad_norm = max_grad_norm
+
+        super(CPULAMB8Bit, self).__init__(
+            'cpu-lamb', params, lr, betas, eps, weight_decay, optim_bits=8, min_8bit_size=min_8bit_size, args=None,
+            percentile_clipping=100, block_wise=4096, max_unorm=0)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        if self.max_grad_norm is not None:
+            iter_params = (param for group in self.param_groups for param in group['params'])
+            torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
+        return super().step(closure=closure)
+
+    @torch.no_grad()
+    def init_state(self, group, p, gindex, pindex):
+        config = self.get_config(gindex, pindex, group)
+        assert config['percentile_clipping'] == 100, "percentile clipping is not implemented on CPU"
+        assert config['max_unorm'] == 0
+
+        if config['optim_bits'] == 32:
+            dtype = torch.float32
+        elif config['optim_bits'] == 8:
+            dtype = torch.uint8
+        else:
+            raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+
+        if p.numel() < config['min_8bit_size']: dtype = torch.float32
+
+        state = self.state[p]
+        state['step'] = 0
+
+        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32,
+                                               device=p.device)
+            state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32,
+                                               device=p.device)
+        elif dtype == torch.uint8:
+            if state['step'] == 0:
+                if 'dynamic' not in self.name2qmap: self.fill_qmap()
+                self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
+                self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
+
+            n = p.numel()
+            blocks = (n - 1) // config['block_wise'] + 1
+
+            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8,
+                                               device=p.device)
+            state['qmap1'] = self.name2qmap['dynamic']
+
+            state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8,
+                                               device=p.device)
+            state['qmap2'] = self.name2qmap['udynamic']
+
+            state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+            state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+
+    @torch.no_grad()
+    def update_step(self, group: Dict[str, Any], p: torch.Tensor, gindex: int, pindex: int):
+        state = self.state[p]
+        config = self.get_config(gindex, pindex, group)
+
+        p_cpu, grad_cpu = p.cpu(), p.grad.cpu()
+        # this is a no-op if parameters are already on CPU
+
+        step = state['step'] = state['step'] + 1
+        beta1, beta2 = group['betas']
+
+        param_delta = self._update_moments_and_compute_delta(
+            state, config, p_cpu, grad_cpu, beta1, beta2, group['eps'], group['weight_decay']
+        )
+        del grad_cpu  # grad_cpu is no longer needed and may be modified if self.reuse_grad_buffers
+
+        step_norm = torch.norm(param_delta)
+        weight_norm = p_cpu.norm().clamp(0, self.clamp_value)
+
+        trust_ratio = weight_norm / step_norm if weight_norm != 0 and step_norm != 0 else 1.0
+        state['weight_norm'], state['step_norm'], state['trust_ratio'] = weight_norm, step_norm, trust_ratio
+
+        # Apply bias to lr to avoid broadcast.
+        bias_correction = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step) if self.bias_correction else 1
+        step_size = group['lr'] * bias_correction
+        p.data.add_(param_delta.to(p.device), alpha=-step_size * trust_ratio)
+
+    def _update_moments_and_compute_delta(
+            self, state: Dict, config: Dict,
+            p_cpu: torch.Tensor, grad_cpu: torch.Tensor,
+            beta1: float, beta2: float, eps: float, weight_decay: float
+    ) -> torch.Tensor:
+        step, block_size, chunk_size = state['step'], config['block_wise'], self.update_chunk_size
+
+        if state['state1'].dtype != torch.uint8:
+            # not quantized: update normally
+            exp_avg, exp_avg_sq = state['state1'], state['state2']
+            exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
+            exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)
+
+            sqrt_out = grad_cpu if self.reuse_grad_buffers else None
+            _denominator = torch.sqrt(exp_avg_sq, out=sqrt_out).add_(eps)
+            param_delta = torch.div(exp_avg, _denominator, out=_denominator)
+            if weight_decay != 0:
+                param_delta.add_(p_cpu, alpha=weight_decay)
+            return param_delta
+        elif p_cpu.numel() <= chunk_size:
+            # quantized tensor within chunk size
+            exp_avg = dequantize_blockwise(
+                state['state1'], (state['absmax1'], state['qmap1']), blocksize=block_size
+            )
+            exp_avg_sq = dequantize_blockwise(
+                state['state2'], (state['absmax2'], state['qmap2']), blocksize=block_size
+            )
+
+            exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
+            exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)
+
+            quantize_blockwise(exp_avg, state['qmap1'], state['absmax1'], out=state['state1'])
+            quantize_blockwise(exp_avg_sq, state['qmap2'], state['absmax2'], out=state['state2'])
+            # note: quantize_blockwise also modifies qmap and absmax in-place
+
+            param_delta = exp_avg.div_(exp_avg_sq.sqrt_().add_(eps))
+            # note: this changes statistics in-place, but it's okay b/c we saved quantized version
+
+            if weight_decay != 0:
+                param_delta.add_(p_cpu, alpha=weight_decay)
+            return param_delta
+
+        else:
+            # very large quantized tensor, compute updates in chunks to save RAM
+            flat_p, flat_grad, flat_state1, flat_state2 = (
+                tensor.view(-1) for tensor in (p_cpu, grad_cpu, state['state1'], state['state2'])
+            )
+            output_buffer = flat_grad if self.reuse_grad_buffers else torch.empty_like(flat_grad)
+
+            for chunk_index, chunk_start in enumerate(range(0, len(flat_p), chunk_size)):
+                chunk = slice(chunk_start, chunk_start + chunk_size)
+                chunk_blocks = slice(chunk_start // block_size, (chunk_start + chunk_size) // block_size)
+
+                chunk_p, chunk_grad = flat_p[chunk], flat_grad[chunk]
+                chunk_state1, chunk_state2 = flat_state1[chunk], flat_state2[chunk]
+                chunk_absmax1, chunk_absmax2 = state['absmax1'][chunk_blocks], state['absmax2'][chunk_blocks]
+                if chunk_state1.storage_offset() != 0:
+                    chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2 = map(
+                        torch.clone, (chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2)
+                    )  # clone chunks to ensure that tensors do not have offsets
+
+                exp_avg_chunk = dequantize_blockwise(
+                    chunk_state1, (chunk_absmax1, state['qmap1']), blocksize=block_size
+                )
+                exp_avg_sq_chunk = dequantize_blockwise(
+                    chunk_state2, (chunk_absmax2, state['qmap2']), blocksize=block_size
+                )
+
+                exp_avg_chunk.mul_(beta1).add_(chunk_grad, alpha=1 - beta1)
+                exp_avg_sq_chunk.mul_(beta2).addcmul_(chunk_grad, chunk_grad, value=1 - beta2)
+
+                # note: output_buffer cannot be modified until this line because it shares memory with grad_cpu
+                del chunk_grad
+
+                flat_state1[chunk], (state['absmax1'][chunk_blocks], state['qmap1']) = quantize_blockwise(
+                    exp_avg_chunk, state['qmap1'], chunk_absmax1, out=chunk_state1
+                )
+                flat_state2[chunk], (state['absmax2'][chunk_blocks], state['qmap2']) = quantize_blockwise(
+                    exp_avg_sq_chunk, state['qmap2'], chunk_absmax2, out=chunk_state2
+                )
+                # note: we need to explicitly assign new quantized tensors because of cloning earlier
+
+                torch.div(exp_avg_chunk, exp_avg_sq_chunk.sqrt_().add_(eps), out=output_buffer[chunk])
+                # note: this changes statistics in-place, but it's okay b/c we saved quantized version
+
+                if weight_decay != 0:
+                    output_buffer[chunk].add_(flat_p[chunk], alpha=weight_decay)
+
+            param_delta = output_buffer.view_as(grad_cpu)
+
+            return param_delta
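
A hedged usage sketch for the new optimizer (hyperparameter values are illustrative, not the run's actual settings). Note that `zero_grad()` must follow every step when `reuse_grad_buffers=True`, since the optimizer may overwrite gradient buffers in place:

```python
import torch
from lib.training.lamb_8bit import CPULAMB8Bit

model = torch.nn.Linear(1024, 1024)  # ~1M weights, above min_8bit_size=65536,
                                     # so their statistics take the 8-bit path
opt = CPULAMB8Bit(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.01,
    max_grad_norm=1.0,        # baked-in global-norm clipping (see step())
    clamp_value=10,           # weight norm clamped to [0, clamp_value]
    reuse_grad_buffers=True,  # gradients may be modified in-place to save RAM
)

model(torch.randn(2, 1024)).sum().backward()
opt.step()
opt.zero_grad()  # required after each step with reuse_grad_buffers=True
```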

+ 2 - 1
requirements.txt

@@ -1,5 +1,6 @@
-git+git://github.com/learning-at-home/hivemind@ed42040
+git+git://github.com/learning-at-home/hivemind@025e095
 git+git://github.com/learning-at-home/dalle-pytorch@weight-sharing
+bitsandbytes-cuda111>=0.26.0
 transformers>=4.9.2
 tokenizers>=0.10.2
 datasets>=1.11.0

+ 8 - 7
run_aux_peer.py

@@ -61,17 +61,17 @@ class CheckpointHandler:
         logger.info("Saving model")
         torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt")
         logger.info("Saving optimizer")
-        torch.save(self.task.collaborative_optimizer.opt.state_dict(), f"{self.local_path}/optimizer_state.pt")
+        torch.save(self.task.collaborative_optimizer.state_dict(), f"{self.local_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
         logger.info("Started uploading to Model Hub")
-        self.repo.push_to_hub(commit_message=f"Step {self.task.collaborative_optimizer.local_step}, loss {current_loss:.3f}")
+        self.repo.push_to_hub(commit_message=f"Step {self.task.collaborative_optimizer.local_epoch}, loss {current_loss:.3f}")
         logger.info("Finished uploading to Model Hub")
 
 
 def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):
     while True:
         time.sleep(peer_args.assist_refresh)
-        task.collaborative_optimizer.step_aux()
+        task.collaborative_optimizer.step()
 
 
 if __name__ == "__main__":
@@ -89,10 +89,11 @@ if __name__ == "__main__":
         checkpoint_handler = CheckpointHandler(task, peer_args)
 
     if peer_args.assist_in_averaging:
-        assert not peer_args.client_mode, "client-mode peers cannot assist in averaging"
-        averaging_thread = threading.Thread(
-            name="AveragingAuxThread", target=assist_averaging_in_background, args=[task, peer_args], daemon=True)
-        averaging_thread.start()
+        # assert not peer_args.client_mode, "client-mode peers cannot assist in averaging"
+        # averaging_thread = threading.Thread(
+        #     name="AveragingAuxThread", target=assist_averaging_in_background, args=[task, peer_args], daemon=True)
+        # averaging_thread.start()
+        raise NotImplementedError('aux peers with hivemind.optim.experimental are not supported yet')
 
     while True:
         metrics_entry = dht.get(peer_args.experiment_prefix + "_metrics", latest=True)

+ 1 - 1
run_trainer_tpu.py

@@ -80,7 +80,7 @@ def main():
         loss, num_accumulated = tpu_manager.step()
         time_delta = time.perf_counter() - start_time
         logger.info(f"Accumulated {num_accumulated} gradients at {num_accumulated / time_delta:.3f} samples/second.")
-        wandb.log({"train/loss": loss, "train/learning_rate": collaborative_optimizer.scheduler.get_lr()[0]})
+        wandb.log({"train/loss": loss, "train/learning_rate": collaborative_optimizer.state_averager.scheduler.get_lr()[0]})
 
         with torch.no_grad():
             for param, grad_from_tpu in zip(model.parameters(), tpu_manager.get_aggregated_gradients()):

+ 16 - 13
task.py

@@ -16,7 +16,7 @@ import utils
 from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
 from data import make_dataset
 from huggingface_auth import authorize_with_huggingface
-from lib.training.clipped_lamb import LambWithGradientClipping
+from lib.training.lamb_8bit import CPULAMB8Bit
 
 
 logger = hivemind.get_logger(__name__)
@@ -115,20 +115,23 @@ class TrainingTask:
     @property
     def collaborative_optimizer(self):
         if self._collaborative_optimizer is None:
-            opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
+            params, opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
             averaging_compression = SizeAdaptiveCompression(
                 threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
-            state_compression = hivemind.Float16Compression()
-            self._collaborative_optimizer = hivemind.CollaborativeOptimizer(
-                dht=self.dht, opt=opt, scheduler=scheduler, prefix=self.peer_args.experiment_prefix,
+            self._collaborative_optimizer = hivemind.Optimizer(
+                dht=self.dht, run_id=self.peer_args.experiment_prefix,
+                params=params, optimizer=opt, scheduler=scheduler,
+                offload_optimizer=True,
+                delay_grad_averaging=False, delay_optimizer_step=True,
                 batch_size_per_step=self.trainer_args.batch_size_per_step,
-                compression=averaging_compression, state_compression=state_compression,
-                client_mode=self.peer_args.client_mode, verbose=True, start=True, **asdict(self.collab_args))
+                grad_compression=averaging_compression, state_averaging_compression=averaging_compression,
+                client_mode=self.peer_args.client_mode, verbose=True,
+                **asdict(self.collab_args))
         return self._collaborative_optimizer
 
     def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
         no_decay = ["bias", "LayerNorm.weight"]
-        optimizer_grouped_parameters = [
+        params = [
             {
                 "params": [p for n, p in self.model.named_parameters()
                            if not any(nd in n for nd in no_decay) and p.requires_grad],
@@ -141,22 +144,22 @@ class TrainingTask:
             },
         ]
 
-        opt = LambWithGradientClipping(
-            optimizer_grouped_parameters,
+        opt = lambda params: CPULAMB8Bit(
+            params,
             lr=training_args.learning_rate,
             betas=(training_args.adam_beta1, training_args.adam_beta2),
             eps=training_args.adam_epsilon,
             weight_decay=training_args.weight_decay,
             max_grad_norm=training_args.max_grad_norm,
             clamp_value=training_args.clamp_value,
-            debias=True,
+            reuse_grad_buffers=True,
         )
 
-        scheduler = get_linear_schedule_with_warmup(
+        scheduler = lambda opt: get_linear_schedule_with_warmup(
             opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
         )
 
-        return opt, scheduler
+        return params, opt, scheduler
 
     @property
     def training_dataset(self):
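
Putting the pieces together: unlike `CollaborativeOptimizer`, `hivemind.Optimizer` receives factory callables for the optimizer and scheduler (hence the lambdas above), which lets it rebuild them on the CPU-side copy of the parameters when `offload_optimizer=True`. A minimal, self-contained sketch of the construction pattern (values are illustrative; the keyword arguments mirror the diff, with a stock Adam standing in for CPULAMB8Bit):

```python
import hivemind
import torch

model = torch.nn.Linear(64, 64)
dht = hivemind.DHT(start=True)

optimizer = hivemind.Optimizer(
    dht=dht,
    run_id="my_experiment",          # replaces CollaborativeOptimizer's `prefix`
    params=model.parameters(),
    optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    scheduler=lambda opt: torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0),
    offload_optimizer=True,          # keep optimizer state on CPU, as in task.py
    delay_grad_averaging=False,
    delay_optimizer_step=True,       # overlap the step with subsequent compute
    batch_size_per_step=4,
    target_batch_size=16384,
    matchmaking_time=15.0,
    averaging_timeout=120.0,
    reuse_grad_buffers=True,
    verbose=True,
)
```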