|
@@ -252,9 +252,8 @@ class Server(threading.Thread):
|
|
},
|
|
},
|
|
]
|
|
]
|
|
|
|
|
|
- optim = OffloadOptimizer(
|
|
|
|
|
|
+ optim = LambWithGradientClipping(
|
|
optimizer_grouped_parameters,
|
|
optimizer_grouped_parameters,
|
|
- optim_cls=LambWithGradientClipping,
|
|
|
|
lr=0.0035355339059327377,
|
|
lr=0.0035355339059327377,
|
|
betas=(0.9, 0.999),
|
|
betas=(0.9, 0.999),
|
|
eps=1e-6,
|
|
eps=1e-6,
|
|
@@ -264,6 +263,18 @@ class Server(threading.Thread):
|
|
debias=True,
|
|
debias=True,
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ # optim = OffloadOptimizer(
|
|
|
|
+ # optimizer_grouped_parameters,
|
|
|
|
+ # optim_cls=LambWithGradientClipping,
|
|
|
|
+ # lr=0.0035355339059327377,
|
|
|
|
+ # betas=(0.9, 0.999),
|
|
|
|
+ # eps=1e-6,
|
|
|
|
+ # weight_decay=0.01,
|
|
|
|
+ # max_grad_norm=1,
|
|
|
|
+ # clamp_value=10000.0,
|
|
|
|
+ # debias=True,
|
|
|
|
+ # )
|
|
|
|
+
|
|
expert.to(device)
|
|
expert.to(device)
|
|
|
|
|
|
if use_averaging:
|
|
if use_averaging:
|
|
@@ -274,6 +285,7 @@ class Server(threading.Thread):
|
|
optim,
|
|
optim,
|
|
dht=dht,
|
|
dht=dht,
|
|
prefix=expert_uid.split(UID_DELIMITER)[0],
|
|
prefix=expert_uid.split(UID_DELIMITER)[0],
|
|
|
|
+ scheduler=scheduler,
|
|
compression=BASE_COMPRESSION_TYPES[averaging_compression],
|
|
compression=BASE_COMPRESSION_TYPES[averaging_compression],
|
|
state_compression=BASE_COMPRESSION_TYPES[averaging_compression],
|
|
state_compression=BASE_COMPRESSION_TYPES[averaging_compression],
|
|
target_batch_size=averaging_target_batch_size,
|
|
target_batch_size=averaging_target_batch_size,
|