@@ -4,7 +4,6 @@ import multiprocessing as mp
 import multiprocessing.synchronize
 import threading
 from contextlib import contextmanager
-from functools import partial
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
@@ -12,7 +11,7 @@ import torch
 from multiaddr import Multiaddr
 
 import hivemind
-from hivemind.compression import BASE_COMPRESSION_TYPES
+from hivemind.compression import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
 from hivemind.dht import DHT
 from hivemind.moe.server.checkpoints import CheckpointSaver, is_directory, load_experts
 from hivemind.moe.server.connection_handler import ConnectionHandler
@@ -27,7 +26,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
-from hivemind.optim import CollaborativeOptimizer, OffloadOptimizer, LambWithGradientClipping
+from hivemind.optim import CollaborativeOptimizer, LambWithGradientClipping
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
@@ -279,6 +278,11 @@ class Server(threading.Thread):
 
             expert.to(device)
 
+            averaging_compression = SizeAdaptiveCompression(
+                threshold=2 ** 16 + 1, less=Float16Compression(),
+                greater_equal=Uniform8BitQuantization()
+            )
+
             if use_averaging:
                 assert averaging_target_batch_size is not None
                 assert averaging_target_group_size is not None
@@ -288,8 +292,8 @@ class Server(threading.Thread):
                     dht=dht,
                     prefix=expert_uid.split(UID_DELIMITER)[0],
                     scheduler=scheduler,
-                    compression=BASE_COMPRESSION_TYPES[averaging_compression],
-                    state_compression=BASE_COMPRESSION_TYPES[averaging_compression],
+                    compression=averaging_compression,
+                    state_compression=Float16Compression(),
                     target_batch_size=averaging_target_batch_size,
                     target_group_size=averaging_target_group_size,
                     min_refresh_period=averaging_min_refresh_period,
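
For context: the new policy routes each averaged tensor by element count. Judging by the constructor's parameter names, tensors with fewer than threshold (2 ** 16 + 1) elements take the `less` branch and are cast to float16, while larger tensors take the `greater_equal` branch and are uniformly quantized to 8 bits; `state_compression` always uses float16. The sketch below is not part of this diff; it only mirrors that threshold rule and estimates the payload sizes it implies. The `expected_wire_bytes` helper is hypothetical, and the small metadata overhead of the real serializer is ignored.

# Minimal sketch, assuming only the constructor shown in the diff above.
import torch

from hivemind.compression import Float16Compression, SizeAdaptiveCompression, Uniform8BitQuantization

averaging_compression = SizeAdaptiveCompression(
    threshold=2 ** 16 + 1,
    less=Float16Compression(),                # <= 2**16 elements: fp16 cast, ~2x smaller than fp32
    greater_equal=Uniform8BitQuantization(),  # larger tensors: ~4x smaller than fp32, but lossier
)


def expected_wire_bytes(tensor: torch.Tensor) -> int:
    """Rough per-tensor payload under the rule above (serialization metadata ignored)."""
    if tensor.numel() < 2 ** 16 + 1:
        return tensor.numel() * 2  # float16: two bytes per element
    return tensor.numel()          # uint8 codes: one byte per element


print(expected_wire_bytes(torch.zeros(1024)))        # small bias-sized tensor -> fp16 path
print(expected_wire_bytes(torch.zeros(1024, 1024)))  # large weight matrix -> 8-bit path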