Max Ryabinin 3 years ago
parent
commit
836192eadc

+ 13 - 5
benchmarks/benchmark_tensor_compression.py

@@ -13,13 +13,14 @@ logger = get_logger(__name__)
 
-def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> float:
+def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> tuple:
     t = time.time()
-    deserialize_torch_tensor(serialize_torch_tensor(tensor, compression_type))
-    return time.time() - t
+    serialized = serialize_torch_tensor(tensor, compression_type)
+    result = deserialize_torch_tensor(serialized)
+    return time.time() - t, (tensor - result).square().mean(), serialized.ByteSize()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--size", type=int, default=10000000, required=False)
+    parser.add_argument("--size", type=int, default=10_000_000, required=False)
     parser.add_argument("--seed", type=int, default=7348, required=False)
     parser.add_argument("--num_iters", type=int, default=30, required=False)
 
@@ -30,7 +31,14 @@ if __name__ == "__main__":
 
     for name, compression_type in CompressionType.items():
         tm = 0
+        distortion = 0
+        num_bytes = 0
         for i in range(args.num_iters):
-            tm += benchmark_compression(X, compression_type)
+            iter_time, iter_distortion, size = benchmark_compression(X, compression_type)
+            tm += iter_time
+            distortion += iter_distortion
+            num_bytes += size
         tm /= args.num_iters
-        logger.info(f"Compression type: {name}, time: {tm}")
+        distortion /= args.num_iters
+        num_bytes /= args.num_iters
+        logger.info(f"Compression type: {name}, time: {tm:.5f}, distortion: {distortion:.5f}, size: {int(num_bytes):d}")

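The benchmark now reports three averaged quantities per codec: round-trip wall-clock time, mean squared distortion between the original and reconstructed tensor, and the size of the serialized protobuf message. A minimal standalone sketch of the same measurement for the new codec (import paths are assumed to match those used by the benchmark script; tensor size and compression type are illustrative):

```python
import time

import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor  # assumed import path
from hivemind.proto.runtime_pb2 import CompressionType

tensor = torch.randn(10_000_000)

start = time.time()
serialized = serialize_torch_tensor(tensor, CompressionType.BLOCKWISE_8BIT)  # runtime_pb2.Tensor message
restored = deserialize_torch_tensor(serialized)
elapsed = time.time() - start

distortion = (tensor - restored).square().mean()  # MSE introduced by 8-bit quantization
print(f"time: {elapsed:.5f}, distortion: {distortion:.5f}, size: {serialized.ByteSize()}")
```
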
+ 2 - 2
hivemind/compression/__init__.py

@@ -10,18 +10,18 @@ import torch
 from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression
 from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole
 from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression
-from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization
+from hivemind.compression.quantization import Quantile8BitQuantization, Uniform8BitQuantization, BlockwiseQuantization
 from hivemind.proto import runtime_pb2
 
 warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
 
-
 BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict(
     NONE=NoCompression(),
     FLOAT16=Float16Compression(),
     MEANSTD_16BIT=ScaledFloat16Compression(),
     QUANTILE_8BIT=Quantile8BitQuantization(),
     UNIFORM_8BIT=Uniform8BitQuantization(),
+    BLOCKWISE_8BIT=BlockwiseQuantization(),
 )
 
 for key in runtime_pb2.CompressionType.keys():

+ 41 - 0
hivemind/compression/quantization.py

@@ -6,6 +6,7 @@ from typing import Tuple
 
 import numpy as np
 import torch
+from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
 
 from hivemind.compression.base import CompressionBase, CompressionInfo
 from hivemind.proto import runtime_pb2
@@ -77,6 +78,46 @@ class Quantile8BitQuantization(Quantization):
         return quantized.numpy().astype(np.uint8), codebook.numpy()
 
 
+class BlockwiseQuantization(Quantization):
+    compression_type = runtime_pb2.BLOCKWISE_8BIT
+    codebook_dtype, indices_dtype = np.float32, np.uint8
+
+    def quantize(self, tensor: torch.Tensor, allow_inplace: bool = False) -> Tuple[
+        np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        return quantize_blockwise(tensor)
+
+    def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
+        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+
+        serialized_data = (
+            np.int64(len(absmax)).tobytes(), np.int64(len(codebook)).tobytes(),
+            absmax.numpy().tobytes(),
+            codebook.numpy().tobytes(),
+            quantized.numpy().tobytes()
+        )
+
+        return runtime_pb2.Tensor(
+            compression=self.compression_type,
+            buffer=b"".join(serialized_data),
+            size=tensor.shape,
+            dtype=tensor.numpy().dtype.name,
+            requires_grad=tensor.requires_grad,
+        )
+
+    def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+        absmax_size = int(np.frombuffer(serialized_tensor.buffer, count=1, dtype=np.int64))
+        codebook_size = int(np.frombuffer(serialized_tensor.buffer, offset=8, count=1, dtype=np.int64))
+        absmax = np.frombuffer(serialized_tensor.buffer, offset=16, count=absmax_size, dtype=self.codebook_dtype)
+        codebook = np.frombuffer(serialized_tensor.buffer, offset=16 + absmax.nbytes, count=codebook_size,
+                                 dtype=self.codebook_dtype)
+        quantized = np.frombuffer(serialized_tensor.buffer, offset=16 + absmax.nbytes + codebook.nbytes,
+                                  dtype=self.indices_dtype)
+        quantized = torch.as_tensor(quantized, dtype=torch.uint8).reshape(tuple(serialized_tensor.size))
+        codebook = torch.as_tensor(codebook)
+        absmax = torch.as_tensor(absmax)
+        return dequantize_blockwise(quantized, (absmax, codebook))
+
+
 def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
     """Return the average value in each bucket"""
     bin_sums = torch.zeros(n_bins).scatter_add_(0, quant_weight.flatten().long(), tensor.flatten())

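BlockwiseQuantization delegates the actual quantization to bitsandbytes: the tensor is split into fixed-size blocks, each block is scaled by its absolute maximum, and the scaled values are mapped onto a shared 256-entry 8-bit codebook. A round-trip sketch using the same functions, assuming the bitsandbytes API of this era, in which quantize_blockwise returns the uint8 indices together with an (absmax, codebook) pair:

```python
import torch
from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise

x = torch.randn(1 << 20)

# quantized: uint8 indices, absmax: one scale per block, codebook: 256 float32 entries shared by all blocks
quantized, (absmax, codebook) = quantize_blockwise(x)
restored = dequantize_blockwise(quantized, (absmax, codebook))

print((x - restored).square().mean())  # small but nonzero reconstruction error
```

The compress/extract pair above then packs absmax, the codebook, and the indices into a single flat buffer prefixed by two int64 length headers, so extract can slice the buffer back apart without any extra metadata.
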
+ 5 - 2
hivemind/moe/client/balancer.py

@@ -3,6 +3,7 @@ import random
 import threading
 from contextlib import contextmanager
 from typing import Dict, List, Tuple
+import time
 
 from hivemind.dht import DHT
 from hivemind.moe.client.expert import RemoteExpert
@@ -15,7 +16,8 @@ logger = get_logger(__name__)
 
 class ExpertBalancer:
     def __init__(
-        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0, **kwargs
+        self, dht: DHT, key: ExpertPrefix, update_period: float = 30.0, initial_throughput: float = 1.0,
+        sleep_timeout: float = 5.0, **kwargs
     ):
         self.dht, self.key = dht, key
         self.initial_throughput, self.ema_kwargs = initial_throughput, kwargs
@@ -29,6 +31,7 @@ class ExpertBalancer:
         self.is_alive.set()
         self.update_trigger, self.update_finished = threading.Event(), threading.Event()
         self.update_period, self.last_update = update_period, get_dht_time()
+        self.sleep_timeout = sleep_timeout
         self.update_thread = threading.Thread(target=self.update_experts_in_background, daemon=True)
         self.update_thread.start()
 
@@ -62,7 +65,7 @@ class ExpertBalancer:
                 )
             if len(self.queue) == 0:
                 logger.warning("Update routine finished, but still no experts available.")
-                time.sleep()
+                time.sleep(self.sleep_timeout)
 
             self.last_update = get_dht_time()
             self.update_finished.set()

+ 29 - 4
hivemind/moe/server/__init__.py

@@ -27,7 +27,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
-from hivemind.optim import CollaborativeOptimizer
+from hivemind.optim import CollaborativeOptimizer, OffloadOptimizer, LambWithGradientClipping
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
@@ -108,7 +108,7 @@ class Server(threading.Thread):
         clip_grad_norm=None,
         num_handlers=None,
         min_batch_size=1,
-        max_batch_size=4096,
+        max_batch_size=1,
         use_averaging: bool = False,
         averaging_target_batch_size: Optional[int] = None,
         averaging_target_group_size: Optional[int] = None,
@@ -225,7 +225,6 @@ class Server(threading.Thread):
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
-        optim_cls = optim_cls if optim_cls is not None else partial(torch.optim.SGD, lr=0.0)
 
         sample_input = name_to_input[expert_cls](3, hidden_dim)
         if isinstance(sample_input, tuple):
@@ -240,12 +239,37 @@ class Server(threading.Thread):
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
+
+            no_decay = ["bias", "LayerNorm.weight"]
+            optimizer_grouped_parameters = [
+                {
+                    "params": [p for n, p in expert.named_parameters() if not any(nd in n for nd in no_decay)],
+                    "weight_decay": 0.01,
+                },
+                {
+                    "params": [p for n, p in expert.named_parameters() if any(nd in n for nd in no_decay)],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            optim = OffloadOptimizer(
+                optimizer_grouped_parameters,
+                optim_cls=LambWithGradientClipping,
+                lr=0.00176,
+                betas=(0.9, 0.999),
+                eps=1e-6,
+                weight_decay=0.01,
+                max_grad_norm=1,
+                clamp_value=10000.0,
+                debias=True,
+            )
+
             expert.to(device)
 
-            optim = optim_cls(expert.parameters())
             if use_averaging:
                 assert averaging_target_batch_size is not None
                 assert averaging_target_group_size is not None
+
                 optim = CollaborativeOptimizer(
                     optim,
                     dht=dht,
@@ -264,6 +288,7 @@ class Server(threading.Thread):
                     verbose=True,
                     start=True,
                 )
+                optim.load_state_from_peers()
 
             experts[expert_uid] = ExpertBackend(
                 name=expert_uid,

+ 2 - 0
hivemind/moe/server/expert_backend.py

@@ -96,6 +96,7 @@ class ExpertBackend:
         self.update_count = 0
         self.examples_processed = 0
 
+    @torch.cuda.amp.autocast()
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
@@ -121,6 +122,7 @@ class ExpertBackend:
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))
 
+    @torch.cuda.amp.autocast()
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually

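torch.cuda.amp.autocast works as a decorator as well as a context manager, so the decorated forward and backward entry points run their CUDA ops in mixed precision. A small sketch of the effect, assuming a CUDA device is available (layer and tensor shapes are illustrative):

```python
import torch

@torch.cuda.amp.autocast()
def run_forward(layer: torch.nn.Linear, x: torch.Tensor) -> torch.Tensor:
    # matrix multiplications execute in float16 while the decorator is active
    return layer(x)

layer = torch.nn.Linear(128, 128).cuda()
out = run_forward(layer, torch.randn(4, 128, device="cuda"))
print(out.dtype)  # torch.float16
```
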
+ 18 - 7
hivemind/moe/server/layers/albert.py

@@ -212,7 +212,7 @@ class RotaryEmbeddings(nn.Module):
             self.register_buffer("cos", cos)
             self.register_buffer("sin", sin)
 
-        return rotate(x, cos[None, offset : seq_len + offset, None, :], sin[None, offset : seq_len + offset, None, :])
+        return rotate(x, cos[None, offset: seq_len + offset, None, :], sin[None, offset: seq_len + offset, None, :])
 
 
 @torch.no_grad()
@@ -311,7 +311,7 @@ class SimpleAttentionCore(nn.Module):
         attention_scores = attention_scores / math.sqrt(query.shape[-1])
 
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length].bool()
         attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores.dtype))
 
         if attention_mask is not None:
@@ -413,7 +413,7 @@ class LeanAlbertEmbeddings(nn.Module):
 
         if self.position_embeddings is not None:
             if position_ids is None:
-                position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+                position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
             position_embeddings = self.position_embeddings(position_ids)
             embeddings += position_embeddings
 
@@ -583,8 +583,10 @@ class LeanAlbertForPreTraining(AlbertForPreTraining, PreTrainedModel):
 
 from hivemind.moe.server.layers.custom_experts import register_expert_class
 
+SEQUENCE_LENGTH = 2048
+
 head_sample_input = lambda batch_size, hid_dim: (
-    torch.randint(low=0, high=1000, size=(batch_size, 512), dtype=torch.long),
+    torch.randint(low=0, high=1000, size=(batch_size, SEQUENCE_LENGTH), dtype=torch.long),
 )
 
 
@@ -594,7 +596,10 @@ class HeadExpert(nn.Module):
         super().__init__()
         config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
         config.hidden_size = hid_dim
+        config.intermediate_size = 4 * config.hidden_size
         config.num_hidden_layers = 12
+        config.vocab_size = 50304
+        config.max_position_embeddings = SEQUENCE_LENGTH
 
         self.encoder = LeanAlbertTransformer(config)
         self.embeddings = LeanAlbertEmbeddings(config)
@@ -606,7 +611,7 @@ class HeadExpert(nn.Module):
         return encoder_outputs
 
 
-body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, 512, hid_dim)),)
+body_sample_input = lambda batch_size, hid_dim: (torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),)
 
 
 @register_expert_class("lm_body", body_sample_input)
@@ -615,7 +620,10 @@ class BodyExpert(nn.Module):
         super().__init__()
         config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
         config.hidden_size = hid_dim
+        config.intermediate_size = 4 * config.hidden_size
         config.num_hidden_layers = 12
+        config.vocab_size = 50304
+        config.max_position_embeddings = SEQUENCE_LENGTH
 
         self.config = config
         self.albert_layer_groups = nn.ModuleList(
@@ -643,8 +651,8 @@ class BodyExpert(nn.Module):
 
 
 tail_sample_input = lambda batch_size, hid_dim: (
-    torch.empty((batch_size, 512, hid_dim)),
-    torch.randint(0, 1000, (batch_size, 512), dtype=torch.long),
+    torch.empty((batch_size, SEQUENCE_LENGTH, hid_dim)),
+    torch.randint(0, 1000, (batch_size, SEQUENCE_LENGTH), dtype=torch.long),
 )
 
 
@@ -654,7 +662,10 @@ class TailExpert(nn.Module):
         super().__init__()
         config = LeanAlbertConfig.from_pretrained("albert-xxlarge-v2")
         config.hidden_size = hid_dim
+        config.intermediate_size = 4 * config.hidden_size
         config.num_hidden_layers = 12
+        config.vocab_size = 50304
+        config.max_position_embeddings = SEQUENCE_LENGTH
 
         self.config = config
         self.albert_layer_groups = nn.ModuleList(

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,4 +1,4 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
-from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.base import DecentralizedOptimizerBase, OffloadOptimizer, LambWithGradientClipping
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD

+ 179 - 0
hivemind/optim/base.py

@@ -1,3 +1,6 @@
+import contextlib
+from typing import Dict, Iterable, Optional, Type, Union
+
 import torch
 
 from hivemind.dht import DHT
@@ -34,3 +37,179 @@ class DecentralizedOptimizerBase(torch.optim.Optimizer):
 
     def shutdown(self):
         raise NotImplementedError()
+
+
+class OptimizerWrapper(torch.optim.Optimizer):
+    r"""
+    A wrapper around a torch.optim.Optimizer that forwards all methods to the wrapped optimizer.
+    """
+
+    def __init__(self, optim: torch.optim.Optimizer):
+        object.__init__(self)
+        self.optim = optim
+
+    @property
+    def defaults(self):
+        return self.optim.defaults
+
+    @property
+    def state(self):
+        return self.optim.state
+
+    def __getstate__(self):
+        return self.optim.__getstate__()
+
+    def __setstate__(self, state):
+        self.optim.__setstate__(state)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({repr(self.optim)})"
+
+    def state_dict(self):
+        return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        return self.optim.load_state_dict(state_dict)
+
+    def step(self, *args, **kwargs):
+        return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        return self.optim.zero_grad(*args, **kwargs)
+
+    @property
+    def param_groups(self):
+        return self.optim.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        return self.optim.add_param_group(param_group)
+
+
+class OffloadOptimizer(OptimizerWrapper):
+    r"""A wrapper that stores optimizer statistics and performs updates on the offloaded device (e.g. CPU RAM)."""
+
+    def __init__(
+        self,
+        param_groups: Union[Iterable[torch.nn.Parameter], Iterable[Dict]],
+        optim_cls: Type[torch.optim.Optimizer],
+        *args,
+        full_sync: bool = True,
+        offload_device=torch.device("cpu"),
+        offload_dtype: Optional[torch.dtype] = None,
+        **kwargs,
+    ):
+        param_groups = list(param_groups)
+        if not isinstance(param_groups[0], dict):
+            param_groups = [{"params": param_groups}]
+        super().__init__(optim_cls(param_groups, *args, **kwargs))
+        self.full_sync = full_sync
+
+        with torch.no_grad():
+            self.offload_params_by_group = tuple(
+                [
+                    torch.nn.Parameter(
+                        torch.empty_like(param, device=offload_device, dtype=offload_dtype),
+                        requires_grad=param.requires_grad,
+                    )
+                    for param in group["params"]
+                ]
+                for group in param_groups
+            )
+
+            for group, offload_params in zip(param_groups, self.offload_params_by_group):
+                for param, offload_param in zip(group["params"], offload_params):
+                    offload_param.copy_(param, non_blocking=True)
+                    if offload_param.grad is None:
+                        offload_param.grad = torch.zeros_like(offload_param)
+                    if param.grad is not None:
+                        offload_param.grad.copy_(param.grad, non_blocking=True)
+
+    @contextlib.contextmanager
+    def _use_offloaded_params(
+        self, *, sync_params_before: bool, sync_grads_before: bool, sync_params_after: bool, sync_grads_after: bool
+    ):
+        assert len(self.param_groups) == len(self.offload_params_by_group)
+        original_params_per_group = [group["params"] for group in self.param_groups]
+        try:
+            with torch.no_grad():
+                for original_params, replacement_params in zip(
+                    original_params_per_group, self.offload_params_by_group
+                ):
+                    for original_param, replacement_param in zip(original_params, replacement_params):
+                        if sync_params_before:
+                            replacement_param.copy_(original_param, non_blocking=True)
+                        if sync_grads_before and original_param.grad is not None:
+                            replacement_param.grad.copy_(original_param.grad, non_blocking=True)
+
+            for group, replacement_params in zip(self.param_groups, self.offload_params_by_group):
+                group["params"] = replacement_params
+            yield self.param_groups
+        finally:
+            for group, original_params in zip(self.param_groups, original_params_per_group):
+                group["params"] = original_params
+
+            with torch.no_grad():
+                for original_params, replacement_params in zip(
+                    original_params_per_group, self.offload_params_by_group
+                ):
+                    for original_param, replacement_param in zip(original_params, replacement_params):
+                        if sync_params_after:
+                            original_param.copy_(replacement_param, non_blocking=True)
+                        if sync_grads_after and original_param.grad is not None:
+                            original_param.grad.copy_(replacement_param.grad)
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not support add_param_group.")
+
+    def step(self, closure=None, *args, **kwargs):
+        assert closure is None, "closure not supported in cpu offload mode"
+        with self._use_offloaded_params(
+            sync_params_before=self.full_sync,
+            sync_grads_before=True,
+            sync_params_after=True,
+            sync_grads_after=self.full_sync,
+        ):
+            return self.optim.step(*args, **kwargs)
+
+    def zero_grad(self, set_to_none: bool = False, *args, **kwargs):
+        if not self.full_sync:
+            torch.optim.Optimizer.zero_grad(self, set_to_none)
+        with self._use_offloaded_params(
+            sync_params_before=self.full_sync,
+            sync_grads_before=self.full_sync,
+            sync_params_after=self.full_sync,
+            sync_grads_after=self.full_sync,
+        ):
+            return super().zero_grad(*args, set_to_none=False, **kwargs)
+
+    def state_dict(self):
+        with self._use_offloaded_params(
+            sync_params_before=self.full_sync,
+            sync_grads_before=self.full_sync,
+            sync_params_after=False,
+            sync_grads_after=False,
+        ):
+            return self.optim.state_dict()
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        with self._use_offloaded_params(
+            sync_params_before=False, sync_grads_before=False, sync_params_after=True, sync_grads_after=self.full_sync
+        ):
+            return self.optim.load_state_dict(state_dict)
+
+
+from torch_optimizer import Lamb
+
+
+class LambWithGradientClipping(Lamb):
+    """A version of LAMB that clips gradients based on their norm."""
+
+    def __init__(self, *args, max_grad_norm: float, **kwargs):
+        self.max_grad_norm = max_grad_norm
+        super().__init__(*args, **kwargs)
+
+    def step(self, *args, **kwargs):
+        iter_params = (param for group in self.param_groups for param in group["params"])
+        torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
+        return super().step(*args, **kwargs)

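Taken together, OffloadOptimizer keeps a CPU-resident copy of every parameter and gradient, temporarily swaps those copies into param_groups around step/zero_grad/state_dict, and lets an arbitrary inner optimizer such as LambWithGradientClipping perform the update on the offload device. A minimal usage sketch, assuming a CUDA model and the classes exported above (model, learning rate, and clipping threshold are illustrative):

```python
import torch
from hivemind.optim import LambWithGradientClipping, OffloadOptimizer

model = torch.nn.Linear(1024, 1024).cuda()
opt = OffloadOptimizer(
    model.parameters(),
    optim_cls=LambWithGradientClipping,  # inner optimizer instantiated over the CPU copies
    lr=1e-3,
    max_grad_norm=1.0,
)

loss = model(torch.randn(16, 1024, device="cuda")).square().mean()
loss.backward()
opt.step()       # grads are copied to CPU, LAMB (with clipping) updates there, params are copied back
opt.zero_grad()
```
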
+ 1 - 0
hivemind/proto/runtime.proto

@@ -32,6 +32,7 @@ enum CompressionType{
   FLOAT16 = 2;
   QUANTILE_8BIT = 3;
   UNIFORM_8BIT = 4;
+  BLOCKWISE_8BIT = 5;
 }
 
 message Tensor {

+ 1 - 1
hivemind/utils/tensor_descr.py

@@ -9,7 +9,7 @@ import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
 
-DUMMY_BATCH_SIZE = 3  # used for dummy runs only
+DUMMY_BATCH_SIZE = 1  # used for dummy runs only
 
 warnings.filterwarnings("ignore", "CUDA initialization*", category=UserWarning)
 

+ 2 - 0
tests/test_compression.py

@@ -37,6 +37,8 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     assert error.square().mean() < beta
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
+    error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.BLOCKWISE_8BIT)) - X
+    assert error.square().mean() < beta
 
     zeros = torch.zeros(5, 5)
     for compression_type in CompressionType.values():