Ver código fonte

Increase compression

Max Ryabinin 3 anos atrás
pai
commit
d800ff438e
1 arquivo alterado com 9 adições e 5 exclusões
  1. 9 5
      hivemind/moe/server/__init__.py

+ 9 - 5
hivemind/moe/server/__init__.py

@@ -4,7 +4,6 @@ import multiprocessing as mp
 import multiprocessing.synchronize
 import threading
 from contextlib import contextmanager
-from functools import partial
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
@@ -12,7 +11,7 @@ import torch
 from multiaddr import Multiaddr
 
 import hivemind
-from hivemind.compression import BASE_COMPRESSION_TYPES
+from hivemind.compression import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
 from hivemind.dht import DHT
 from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
@@ -27,7 +26,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
-from hivemind.optim import CollaborativeOptimizer, OffloadOptimizer, LambWithGradientClipping
+from hivemind.optim import CollaborativeOptimizer, LambWithGradientClipping
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
@@ -279,6 +278,11 @@ class Server(threading.Thread):
 
             expert.to(device)
 
+            averaging_compression = SizeAdaptiveCompression(
+                threshold=2 ** 16 + 1, less=Float16Compression(),
+                greater_equal=Uniform8BitQuantization()
+            )
+
             if use_averaging:
                 assert averaging_target_batch_size is not None
                 assert averaging_target_group_size is not None
@@ -288,8 +292,8 @@ class Server(threading.Thread):
                     dht=dht,
                     prefix=expert_uid.split(UID_DELIMITER)[0],
                     scheduler=scheduler,
-                    compression=BASE_COMPRESSION_TYPES[averaging_compression],
-                    state_compression=BASE_COMPRESSION_TYPES[averaging_compression],
+                    compression=averaging_compression,
+                    state_compression=Float16Compression(),
                     target_batch_size=averaging_target_batch_size,
                     target_group_size=averaging_target_group_size,
                     min_refresh_period=averaging_min_refresh_period,