Max Ryabinin 3 years ago
Parent
Commit
836192eadc

+ 13 - 5
benchmarks/benchmark_tensor_compression.py

@@ -13,13 +13,14 @@ logger = get_logger(__name__)
 
 
 def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
     t = time.time()
-    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
-    return time.time() - t
+    serialized = serialize_torch_tensor(tensor, compression_type)
+    result = deserialize_torch_tensor(serialized)
+    return time.time() - t, (tensor - result).square().mean(), serialized.ByteSize()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--size", type=int, default=10000000, required=False)
+    parser.add_argument("--size", type=int, default=10_000_000, required=False)
     parser.add_argument("--seed", type=int, default=7348, required=False)
     parser.add_argument("--num_iters", type=int, default=30, required=False)
 
@@ -30,7 +31,14 @@ if __name__ == "__main__":
 
 
     for name, compression_type in CompressionType.items():
         tm = 0
+        distortion = 0
+        bytes = 0
         for i in range(args.num_iters):
-            tm += benchmark_compression(X, compression_type)
+            iter_time, iter_distortion, size = benchmark_compression(X, compression_type)
+            tm += iter_time
+            distortion += iter_distortion
+            bytes += size
         tm /= args.num_iters
-        logger.info(f"Compression type: {name}, time: {tm}")
+        distortion /= args.num_iters
+        bytes /= args.num_iters
+        logger.info(f"Compression type: {name}, time: {tm:.5f}, distortion: {distortion:.5f}, size: {int(bytes):d}")
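With this change the benchmark reports three numbers per compression type: the serialize/deserialize round-trip time, the mean squared error between the original and restored tensor (logged as distortion), and the size of the serialized protobuf in bytes. A standalone sketch of the same measurement; the imports follow the hivemind API the benchmark uses, and the tensor size and compression type here are illustrative:

    import time
    import torch
    from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
    from hivemind.proto.runtime_pb2 import CompressionType

    x = torch.randn(1_000_000)
    start = time.time()
    serialized = serialize_torch_tensor(x, CompressionType.FLOAT16)  # round trip for one type
    restored = deserialize_torch_tensor(serialized)
    elapsed = time.time() - start
    distortion = (x - restored).square().mean().item()  # MSE between original and restored values
    size_bytes = serialized.ByteSize()                   # serialized protobuf size in bytes
    print(f"time: {elapsed:.5f}, distortion: {distortion:.5f}, size: {size_bytes}")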

+ 2 - 2
hivemind/compression/__init__.py

@@ -10,18 +10,18 @@ import torch
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization, BlockwiseQuantization
 from hivemind.proto import runtime_pb2
 
 warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
 
-
 BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
     NONE=NoCompression(),
     FLOAT16=Float16Compression(),
     MEANSTD_16BIT=ScaledFloat16Compression(),
     QUANTILE_8BIT=Quantile8BitQuantization(),
     UNIFORM_8BIT=Uniform8BitQuantization(),
+    BLOCKWISE_8BIT=BlockwiseQuantization(),
 )
 
 for key in runtime_pb2.CompressionType.keys():
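The registration loop is truncated here; presumably it verifies that every CompressionType enum key has a matching entry in BASE_COMPRESSION_TYPES, which the new BLOCKWISE_8BIT entry now satisfies. A hedged sketch of that invariant:

    from hivemind.compression import BASE_COMPRESSION_TYPES
    from hivemind.proto import runtime_pb2

    for key in runtime_pb2.CompressionType.keys():
        # every wire-level compression type should map to a registered implementation
        assert key in BASE_COMPRESSION_TYPES, f"no compression registered for {key}"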

+ 41 - 0
hivemind/compression/quantization.py

@@ -6,6 +6,7 @@ from typing import Tuple
 
 
 import numpy as np
 import torch
+from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
 
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.proto import runtime_pb2
@@ -77,6 +78,46 @@ class Quantile8BitQuantization(Quantization):
         return quantized.numpy().astype(np.uint8), codebook.numpy()
 
 
+class BlockwiseQuantization(Quantization):
+    compression_type = runtime_pb2.BLOCKWISE_8BIT
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[
+        np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        return quantize_blockwise(tensor)
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+
+        serialized_data = (
+            np.int64(len(absmax)).tobytes(), np.int64(len(codebook)).tobytes(),
+            absmax.numpy().tobytes(),
+            codebook.numpy().tobytes(),
+            quantized.numpy().tobytes()
+        )
+
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=b"".join(serialized_data),
+            size=tensor.shape,
+            dtype=tensor.numpy().dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
+        absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
+        codebook = np.frombuffer(serialized_tensor.buffer, offset=16 + absmax.nbytes, count=codebook_size,
+                                 dtype=self.codebook_dtype)
+        quantized = np.frombuffer(serialized_tensor.buffer, offset=16 + absmax.nbytes + codebook.nbytes,
+                                  dtype=self.indices_dtype)
+        quantized = torch.as_tensor(quantized, dtype=torch.uint8).reshape(tuple(serialized_tensor.size))
+        codebook = torch.as_tensor(codebook)
+        absmax = torch.as_tensor(absmax)
+        return dequantize_blockwise(quantized, (absmax, codebook))
+
+
 def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
     """Return the average value in each bucket"""
     bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())
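For reference, the buffer that compress() produces above has a flat layout: two little-endian int64 headers holding the absmax and codebook lengths, then the float32 per-block absmax values, the float32 codebook, and finally the uint8 quantized indices. A small parsing sketch that mirrors extract(); the function name is illustrative:

    import numpy as np

    def parse_blockwise_buffer(buffer: bytes):
        # [int64 absmax_len][int64 codebook_len][absmax fp32][codebook fp32][indices uint8]
        absmax_size = int(np.frombuffer(buffer, count=1, dtype=np.int64))
        codebook_size = int(np.frombuffer(buffer, offset=8, count=1, dtype=np.int64))
        absmax = np.frombuffer(buffer, offset=16, count=absmax_size, dtype=np.float32)
        codebook = np.frombuffer(buffer, offset=16 + absmax.nbytes, count=codebook_size, dtype=np.float32)
        indices = np.frombuffer(buffer, offset=16 + absmax.nbytes + codebook.nbytes, dtype=np.uint8)
        return absmax, codebook, indices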

+ 5 - 2
hivemind/moe/client/balancer.py

@@ -3,6 +3,7 @@ import random
 import threading
 from contextlib import contextmanager
 from typing import Dict, List, Tuple
+import time
 
 from hivemind.dht import DHT
 from hivemind.moe.client.expert import RemoteExpert
@@ -15,7 +16,8 @@ logger = get_logger(__name__)
 
 
 class ExpertBalancer:
     def __init__(
-        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0, **kwargs
+        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
+        sleep_timeout: float = 5.0, **kwargs
     ):
         self.dht, self.key = dht, key
         self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
@@ -29,6 +31,7 @@ class ExpertBalancer:
         self.is_alive.set()
         self.update_trigger, self.update_finished = threading.Event(), threading.Event()
         self.update_period, self.last_update = update_period, get_dht_time()
+        self.sleep_timeout = sleep_timeout
         self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
         self.update_thread.start()
 
@@ -62,7 +65,7 @@ class ExpertBalancer:
                 )
             if len(self.queue) == 0:
                 logger.warning("Update routine finished, but still no experts available.")
-                time.sleep()
+                time.sleep(self.sleep_timeout)
 
             self.last_update = get_dht_time()
             self.update_finished.set()
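The new sleep_timeout only bounds how long the balancer waits before re-checking when the DHT returned no experts, replacing the previous bare time.sleep() call (which raises a TypeError). A rough, simplified sketch of the pattern; the helper below is illustrative, not part of the class:

    import time

    def wait_until_experts_appear(balancer, max_retries: int = 10) -> bool:
        # back off by balancer.sleep_timeout between checks instead of busy-waiting
        for _ in range(max_retries):
            if len(balancer.queue) > 0:
                return True
            time.sleep(balancer.sleep_timeout)
        return False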

+ 29 - 4
hivemind/moe/server/__init__.py

@@ -27,7 +27,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
-from hivemind.optim import CollaborativeOptimizer
+from hivemind.optim import CollaborativeOptimizer, OffloadOptimizer, LambWithGradientClipping
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
@@ -108,7 +108,7 @@ class Server(threading.Thread):
         clip_grad_norm=None,
         num_handlers=None,
         min_batch_size=1,
-        max_batch_size=4096,
+        max_batch_size=1,
         use_averaging: bool = False,
         averaging_target_batch_size: Optional[int] = None,
         averaging_target_group_size: Optional[int] = None,
@@ -225,7 +225,6 @@ class Server(threading.Thread):
 
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
 
         sample_input = name_to_input[expert_cls](3, hidden_dim)
         if isinstance(sample_input, tuple):
@@ -240,12 +239,37 @@ class Server(threading.Thread):
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
+
+            no_decay = ["bias", "LayerNorm.weight"]
+            optimizer_grouped_parameters = [
+                {
+                    "params": [p for n, p in expert.named_parameters() if not any(nd in n for nd in no_decay)],
+                    "weight_decay": 0.01,
+                },
+                {
+                    "params": [p for n, p in expert.named_parameters() if any(nd in n for nd in no_decay)],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            optim = OffloadOptimizer(
+                optimizer_grouped_parameters,
+                optim_cls=LambWithGradientClipping,
+                lr=0.00176,
+                betas=(0.9, 0.999),
+                eps=1e-6,
+                weight_decay=0.01,
+                max_grad_norm=1,
+                clamp_value=10000.0,
+                debias=True,
+            )
+
             expert.to(device)
 
-            optim = optim_cls(expert.parameters())
             if use_averaging:
                 assert averaging_target_batch_size is not None
                 assert averaging_target_group_size is not None
+
                 optim = CollaborativeOptimizer(
                     optim,
                     dht=dht,
@@ -264,6 +288,7 @@ class Server(threading.Thread):
                     verbose=True,
                     start=True,
                 )
+                optim.load_state_from_peers()
 
             experts[expert_uid] = ExpertBackend(
                 name=expert_uid,
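For context, the per-expert optimizer built above layers three pieces: parameter groups that exempt biases and LayerNorm weights from weight decay, a LAMB variant with gradient-norm clipping, and a CPU-offloading wrapper (optionally wrapped again by CollaborativeOptimizer). A standalone sketch with the same hardcoded hyperparameters, using a plain Linear layer as a stand-in for an expert; it assumes torch_optimizer is installed, since LambWithGradientClipping builds on its Lamb:

    import torch
    from hivemind.optim import LambWithGradientClipping, OffloadOptimizer

    model = torch.nn.Linear(1024, 1024)  # stand-in for name_to_block[expert_cls](hidden_dim)
    no_decay = ["bias", "LayerNorm.weight"]
    grouped_params = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         "weight_decay": 0.01},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0},
    ]
    optim = OffloadOptimizer(
        grouped_params,
        optim_cls=LambWithGradientClipping,
        lr=0.00176, betas=(0.9, 0.999), eps=1e-6,
        weight_decay=0.01, max_grad_norm=1, clamp_value=10000.0, debias=True,
    )

In the server code the expert is moved to the target device only after this optimizer is constructed, so the offloaded copies stay on the CPU while the live parameters move to the GPU.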

+ 2 - 0
hivemind/moe/server/expert_backend.py

@@ -96,6 +96,7 @@ class ExpertBackend:
         self.update_count = 0
         self.examples_processed = 0
 
+    @torch.cuda.amp.autocast()
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
@@ -121,6 +122,7 @@ class ExpertBackend:
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))
 
+    @torch.cuda.amp.autocast()
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually
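The two decorators above run each expert's forward and backward passes under CUDA automatic mixed precision; the decorator form is equivalent to wrapping the method body in an autocast context. A minimal illustration (requires a CUDA device; full AMP training usually also pairs autocast with a gradient scaler, which this diff does not add):

    import torch

    model = torch.nn.Linear(16, 16).cuda()
    x = torch.randn(4, 16, device="cuda")

    @torch.cuda.amp.autocast()
    def run_forward(inp: torch.Tensor) -> torch.Tensor:
        return model(inp)  # matmuls execute in reduced precision where considered safe

    y_decorated = run_forward(x)
    with torch.cuda.amp.autocast():  # context-manager form, same effect as the decorator
        y_context = model(x)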

+ 18 - 7
hivemind/moe/server/layers/albert.py

@@ -212,7 +212,7 @@ class RotaryEmbeddings(nn.Module):
             self.register_buffer("cos", cos)
             self.register_buffer("sin", sin)
 
-        return rotate(x, cos[None, offset : seq_len + offset, None, :], sin[None, offset : seq_len + offset, None, :])
+        return rotate(x, cos[None, offset: seq_len + offset, None, :], sin[None, offset: seq_len + offset, None, :])
 
 
 @torch.no_grad()
@@ -311,7 +311,7 @@ class SimpleAttentionCore(nn.Module):
         attention_scores = attention_scores / math.sqrt(query.shape[-1])
 
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
         attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores.dtype))
 
         if attention_mask is not None:
@@ -413,7 +413,7 @@ class LeanAlbertEmbeddings(nn.Module):
 
 
         if self.position_embeddings is not None:
             if position_ids is None:
-                position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+                position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
             position_embeddings = self.position_embeddings(position_ids)
             embeddings += position_embeddings
 
@@ -583,8 +583,10 @@ class LeanAlbertForPreTraining(AlbertForPreTraining, PreTrainedModel):
 
 
 from hivemind.moe.server.layers.custom_experts import register_expert_class
 
+SEQUENCE_LENGTH = 2048
+
 head_sample_input = lambda batch_size, hid_dim: (
-    torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
+    torch.randint(low=0, high=1000, size=(batch_size, SEQUENCE_LENGTH), dtype=torch.long),
 )
 
 
@@ -594,7 +596,10 @@ class HeadExpert(nn.Module):
         super().__init__()
         config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
         config.hidden_size = hid_dim
+        config.intermediate_size = 4 * config.hidden_size
         config.num_hidden_layers = 12
+        config.vocab_size = 50304
+        config.max_position_embeddings = SEQUENCE_LENGTH
 
         self.encoder = LeanAlbertTransformer(config)
         self.embeddings = LeanAlbertEmbeddings(config)
@@ -606,7 +611,7 @@ class HeadExpert(nn.Module):
         return encoder_outputs
 
 
-body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, 512, hid_dim)),)
+body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),)
 
 
 @register_expert_class("lm_body", body_sample_input)
@@ -615,7 +620,10 @@ class BodyExpert(nn.Module):
         super().__init__()
         config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
         config.hidden_size = hid_dim
+        config.intermediate_size = 4 * config.hidden_size
         config.num_hidden_layers = 12
+        config.vocab_size = 50304
+        config.max_position_embeddings = SEQUENCE_LENGTH
 
         self.config = config
         self.albert_layer_groups = nn.ModuleList(
@@ -643,8 +651,8 @@ class BodyExpert(nn.Module):
 
 
 
 
 tail_sample_input = lambda batch_size, hid_dim: (
-    torch.empty((batch_size, 512, hid_dim)),
-    torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
+    torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),
+    torch.randint(0, 1000, (batch_size, SEQUENCE_LENGTH), dtype=torch.long),
 )
 
 
@@ -654,7 +662,10 @@ class TailExpert(nn.Module):
         super().__init__()
         config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
         config.hidden_size = hid_dim
+        config.intermediate_size = 4 * config.hidden_size
         config.num_hidden_layers = 12
+        config.vocab_size = 50304
+        config.max_position_embeddings = SEQUENCE_LENGTH
 
         self.config = config
         self.albert_layer_groups = nn.ModuleList(

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
-from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.base import DecentralizedOptimizerBase, OffloadOptimizer, LambWithGradientClipping
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD

+ 179 - 0
hivemind/optim/base.py

@@ -1,3 +1,6 @@
+import contextlib
+from typing import Dict, Iterable, Optional, Type, Union
+
 import torch
 
 from hivemind.dht import DHT
@@ -34,3 +37,179 @@ class DecentralizedOptimizerBase(torch.optim.Optimizer):
 
 
     def shutdown(self):
         raise NotImplementedError()
+
+
+class OptimizerWrapper(torch.optim.Optimizer):
+    r"""
+    A wrapper for pytorch.optimizer that forwards all methods to the wrapped optimizer
+    """
+
+    def __init__(self, optim: torch.optim.Optimizer):
+        object.__init__(self)
+        self.optim = optim
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self.optim)})"
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        return self.optim.load_state_dict(state_dict)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        return self.optim.zero_grad(*args, **kwargs)
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        return self.optim.add_param_group(param_group)
+
+
+class OffloadOptimizer(OptimizerWrapper):
+    r"""A wrapper that stores optimizer statistics and performs updates on the offloaded device (e.g. CPU RAM)."""
+
+    def __init__(
+        self,
+        param_groups: Union[Iterable[torch.nn.Parameter], Iterable[Dict]],
+        optim_cls: Type[torch.optim.Optimizer],
+        *args,
+        full_sync: bool = True,
+        offload_device=torch.device("cpu"),
+        offload_dtype: Optional[torch.dtype] = None,
+        **kwargs,
+    ):
+        param_groups = list(param_groups)
+        if not isinstance(param_groups[0], dict):
+            param_groups = [{"params": param_groups}]
+        super().__init__(optim_cls(param_groups, *args, **kwargs))
+        self.full_sync = full_sync
+
+        with torch.no_grad():
+            self.offload_params_by_group = tuple(
+                [
+                    torch.nn.Parameter(
+                        torch.empty_like(param, device=offload_device, dtype=offload_dtype),
+                        requires_grad=param.requires_grad,
+                    )
+                    for param in group["params"]
+                ]
+                for group in param_groups
+            )
+
+            for group, offload_params in zip(param_groups, self.offload_params_by_group):
+                for param, offload_param in zip(group["params"], offload_params):
+                    offload_param.copy_(param, non_blocking=True)
+                    if offload_param.grad is None:
+                        offload_param.grad = torch.zeros_like(offload_param)
+                    if param.grad is not None:
+                        offload_param.grad.copy_(param.grad, non_blocking=True)
+
+    @contextlib.contextmanager
+    def _use_offloaded_params(
+        self, *, sync_params_before: bool, sync_grads_before: bool, sync_params_after: bool, sync_grads_after: bool
+    ):
+        assert len(self.param_groups) == len(self.offload_params_by_group)
+        original_params_per_group = [group["params"] for group in self.param_groups]
+        try:
+            with torch.no_grad():
+                for original_params, replacement_params in zip(
+                    original_params_per_group, self.offload_params_by_group
+                ):
+                    for original_param, replacement_param in zip(original_params, replacement_params):
+                        if sync_params_before:
+                            replacement_param.copy_(original_param, non_blocking=True)
+                        if sync_grads_before and original_param.grad is not None:
+                            replacement_param.grad.copy_(original_param.grad, non_blocking=True)
+
+            for group, replacement_params in zip(self.param_groups, self.offload_params_by_group):
+                group["params"] = replacement_params
+            yield self.param_groups
+        finally:
+            for group, original_params in zip(self.param_groups, original_params_per_group):
+                group["params"] = original_params
+
+            with torch.no_grad():
+                for original_params, replacement_params in zip(
+                    original_params_per_group, self.offload_params_by_group
+                ):
+                    for original_param, replacement_param in zip(original_params, replacement_params):
+                        if sync_params_after:
+                            original_param.copy_(replacement_param, non_blocking=True)
+                        if sync_grads_after and original_param.grad is not None:
+                            original_param.grad.copy_(replacement_param.grad)
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not support add_param_group.")
+
+    def step(self, closure=None, *args, **kwargs):
+        assert closure is None, "closure not supported in cpu offload mode"
+        with self._use_offloaded_params(
+            sync_params_before=self.full_sync,
+            sync_grads_before=True,
+            sync_params_after=True,
+            sync_grads_after=self.full_sync,
+        ):
+            return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, set_to_none: bool = False, *args, **kwargs):
+        if not self.full_sync:
+            torch.optim.Optimizer.zero_grad(self, set_to_none)
+        with self._use_offloaded_params(
+            sync_params_before=self.full_sync,
+            sync_grads_before=self.full_sync,
+            sync_params_after=self.full_sync,
+            sync_grads_after=self.full_sync,
+        ):
+            return super().zero_grad(*args, set_to_none=False, **kwargs)
+
+    def state_dict(self):
+        with self._use_offloaded_params(
+            sync_params_before=self.full_sync,
+            sync_grads_before=self.full_sync,
+            sync_params_after=False,
+            sync_grads_after=False,
+        ):
+            return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        with self._use_offloaded_params(
+            sync_params_before=False, sync_grads_before=False, sync_params_after=True, sync_grads_after=self.full_sync
+        ):
+            return self.optim.load_state_dict(state_dict)
+
+
+import torch
+from torch_optimizer import Lamb
+
+
+class LambWithGradientClipping(Lamb):
+    """A version of LAMB that clips gradients based on their norm."""
+
+    def __init__(self, *args, max_grad_norm: float, **kwargs):
+        self.max_grad_norm = max_grad_norm
+        super().__init__(*args, **kwargs)
+
+    def step(self, *args, **kwargs):
+        iter_params = (param for group in self.param_groups for param in group["params"])
+        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
+        return super().step(*args, **kwargs)
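A minimal usage sketch for the new OffloadOptimizer (not part of the commit): the wrapped optimizer and its state live on the offload device (CPU by default), gradients are copied there before each step, and updated parameters are copied back afterwards. The model and hyperparameters below are illustrative:

    import torch
    from hivemind.optim import OffloadOptimizer

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1))
    opt = OffloadOptimizer(model.parameters(), optim_cls=torch.optim.Adam, lr=1e-3)

    x, target = torch.randn(8, 64), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), target)
    loss.backward()
    opt.step()       # grads copied to the CPU replicas, Adam runs there, params copied back
    opt.zero_grad()  # with the default full_sync=True this zeroes offloaded and original grads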

+ 1 - 0
hivemind/proto/runtime.proto

@@ -32,6 +32,7 @@ enum CompressionType{
   FLOAT16 = 2;
   QUANTILE_8BIT = 3;
   UNIFORM_8BIT = 4;
+  BLOCKWISE_8BIT = 5;
 }
 
 message Tensor {

+ 1 - 1
hivemind/utils/tensor_descr.py

@@ -9,7 +9,7 @@ import torch
 
 
 from hivemind.proto.runtime_pb2 import CompressionType
 
-DUMMY_BATCH_SIZE = 3  # used for dummy runs only
+DUMMY_BATCH_SIZE = 1  # used for dummy runs only
 
 warnings.filterwarnings("ignore", "CUDA initialization*", category=UserWarning)
 

+ 2 - 0
tests/test_compression.py

@@ -37,6 +37,8 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     assert error.square().mean() < beta
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.BLOCKWISE_8BIT)) - X
+    assert error.square().mean() < beta
 
 
     zeros = torch.zeros(5, 5)
     for compression_type in CompressionType.values():