@@ -26,7 +26,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
-from hivemind.optim import CollaborativeOptimizer, LambWithGradientClipping
+from hivemind.optim import CollaborativeOptimizer, LambWithGradientClipping, OffloadOptimizer
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
@@ -120,6 +120,7 @@ class Server(threading.Thread):
         reuse_grad_buffers=True,
         device=None,
         fp16=False,
+        offload=False,
         no_dht=False,
         dht_port=None,
         dht_listen_on=None,
@@ -252,8 +253,7 @@ class Server(threading.Thread):
             },
         ]

-        optim = LambWithGradientClipping(
-            optimizer_grouped_parameters,
+        optim_kwargs = dict(
             lr=0.0035355339059327377,
             betas=(0.9, 0.999),
             eps=1e-6,
@@ -263,31 +263,31 @@ class Server(threading.Thread):
             debias=True,
         )

-        scheduler = scheduler(optim, num_warmup_steps=num_warmup_steps, num_training_steps=num_total_steps)
+        if offload:
+            optim = OffloadOptimizer(
+                optimizer_grouped_parameters,
+                optim_cls=LambWithGradientClipping,
+                **optim_kwargs
+            )
+        else:
+            optim = LambWithGradientClipping(
+                optimizer_grouped_parameters,
+                **optim_kwargs
+            )

-        # optim = OffloadOptimizer(
-        #     optimizer_grouped_parameters,
-        #     optim_cls=LambWithGradientClipping,
-        #     lr=0.0035355339059327377,
-        #     betas=(0.9, 0.999),
-        #     eps=1e-6,
-        #     weight_decay=0.01,
-        #     max_grad_norm=1,
-        #     clamp_value=10000.0,
-        #     debias=True,
-        # )
+        scheduler = scheduler(optim, num_warmup_steps=num_warmup_steps, num_training_steps=num_total_steps)

         expert.to(device)

-        averaging_compression = SizeAdaptiveCompression(
-            threshold=2 ** 16 + 1, less=Float16Compression(),
-            greater_equal=Uniform8BitQuantization()
-        )
-
         if use_averaging:
             assert averaging_target_batch_size is not None
             assert averaging_target_group_size is not None

+            averaging_compression = SizeAdaptiveCompression(
+                threshold=2 ** 16 + 1, less=Float16Compression(),
+                greater_equal=Uniform8BitQuantization()
+            )
+
             optim = CollaborativeOptimizer(
                 optim,
                 dht=dht,
@@ -309,17 +309,29 @@ class Server(threading.Thread):
             )
             optim.load_state_from_peers()

-        experts[expert_uid] = ExpertBackend(
-            name=expert_uid,
-            expert=expert,
-            args_schema=args_schema,
-            optimizer=optim,
-            device=device,
-            fp16=fp16,
-            clip_grad_norm=clip_grad_norm,
-            min_batch_size=min_batch_size,
-            max_batch_size=max_batch_size,
-        )
+            experts[expert_uid] = ExpertBackend(
+                name=expert_uid,
+                expert=expert,
+                args_schema=args_schema,
+                optimizer=optim,
+                device=device,
+                fp16=fp16,
+                clip_grad_norm=clip_grad_norm,
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )
+        else:
+            experts[expert_uid] = ExpertBackend(
+                name=expert_uid,
+                expert=expert,
+                args_schema=args_schema,
+                optimizer=optim,
+                device=device,
+                fp16=fp16,
+                clip_grad_norm=clip_grad_norm,
+                min_batch_size=min_batch_size,
+                max_batch_size=max_batch_size,
+            )

         if checkpoint_dir is not None:
             load_experts(experts, checkpoint_dir)