
Add optional offload

Max Ryabinin, 3 years ago
parent commit b26d61b1c4
2 files changed with 44 additions and 31 deletions
  1. hivemind/hivemind_cli/run_server.py (+1 -0)
  2. hivemind/moe/server/__init__.py (+43 -31)

+ 1 - 0
hivemind/hivemind_cli/run_server.py

@@ -48,6 +48,7 @@ def main():
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
     parser.add_argument('--fp16',action='store_true',help='Use mixed precision during forward and backward steps')
+    parser.add_argument('--offload',action='store_true',help='Offload optimizer statistics to CPU to reduce GPU memory usage')
 
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
     parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',

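For reference, a minimal sketch (not part of this diff) of how the new flag would reach the server; the pass-through of all parsed CLI arguments into Server.create as keyword arguments is an assumption about run_server.py's wiring:

    # Hypothetical wiring, assumed from the CLI structure above: every
    # parsed flag is forwarded, so --offload becomes the offload=...
    # keyword consumed in hivemind/moe/server/__init__.py below.
    args = vars(parser.parse_args())
    server = Server.create(**args)
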
+ 43 - 31
hivemind/moe/server/__init__.py

@@ -26,7 +26,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
-from hivemind.optim import CollaborativeOptimizer, LambWithGradientClipping
+from hivemind.optim import CollaborativeOptimizer, LambWithGradientClipping, OffloadOptimizer
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
@@ -120,6 +120,7 @@ class Server(threading.Thread):
         reuse_grad_buffers=True,
         device=None,
         fp16=False,
+        offload=False,
         no_dht=False,
         dht_port=None,
         dht_listen_on=None,
@@ -252,8 +253,7 @@ class Server(threading.Thread):
                 },
             ]
 
-            optim = LambWithGradientClipping(
-                optimizer_grouped_parameters,
+            optim_kwargs = dict(
                 lr=0.0035355339059327377,
                 betas=(0.9, 0.999),
                 eps=1e-6,
@@ -263,31 +263,31 @@ class Server(threading.Thread):
                 debias=True,
             )
 
-            scheduler = scheduler(optim, num_warmup_steps=num_warmup_steps, num_training_steps=num_total_steps)
+            if offload:
+                optim = OffloadOptimizer(
+                    optimizer_grouped_parameters,
+                    optim_cls=LambWithGradientClipping,
+                    **optim_kwargs
+                )
+            else:
+                optim = LambWithGradientClipping(
+                    optimizer_grouped_parameters,
+                    **optim_kwargs
+                )
 
-            # optim = OffloadOptimizer(
-            #     optimizer_grouped_parameters,
-            #     optim_cls=LambWithGradientClipping,
-            #     lr=0.0035355339059327377,
-            #     betas=(0.9, 0.999),
-            #     eps=1e-6,
-            #     weight_decay=0.01,
-            #     max_grad_norm=1,
-            #     clamp_value=10000.0,
-            #     debias=True,
-            # )
+            scheduler = scheduler(optim, num_warmup_steps=num_warmup_steps, num_training_steps=num_total_steps)
 
             expert.to(device)
 
-            averaging_compression = SizeAdaptiveCompression(
-                threshold=2 ** 16 + 1, less=Float16Compression(),
-                greater_equal=Uniform8BitQuantization()
-            )
-
             if use_averaging:
                 assert averaging_target_batch_size is not None
                 assert averaging_target_group_size is not None
 
+                averaging_compression = SizeAdaptiveCompression(
+                    threshold=2 ** 16 + 1, less=Float16Compression(),
+                    greater_equal=Uniform8BitQuantization()
+                )
+
                 optim = CollaborativeOptimizer(
                     optim,
                     dht=dht,
@@ -309,17 +309,29 @@ class Server(threading.Thread):
                 )
                 optim.load_state_from_peers()
 
-            experts[expert_uid] = ExpertBackend(
-                name=expert_uid,
-                expert=expert,
-                args_schema=args_schema,
-                optimizer=optim,
-                device=device,
-                fp16=fp16,
-                clip_grad_norm=clip_grad_norm,
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
+                experts[expert_uid] = ExpertBackend(
+                    name=expert_uid,
+                    expert=expert,
+                    args_schema=args_schema,
+                    optimizer=optim,
+                    device=device,
+                    fp16=fp16,
+                    clip_grad_norm=clip_grad_norm,
+                    min_batch_size=min_batch_size,
+                    max_batch_size=max_batch_size,
+                )
+            else:
+                experts[expert_uid] = ExpertBackend(
+                    name=expert_uid,
+                    expert=expert,
+                    args_schema=args_schema,
+                    optimizer=optim,
+                    device=device,
+                    fp16=fp16,
+                    clip_grad_norm=clip_grad_norm,
+                    min_batch_size=min_batch_size,
+                    max_batch_size=max_batch_size,
+                )
 
         if checkpoint_dir is not None:
             load_experts(experts, checkpoint_dir)
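
The net effect of the second file is a single switch between plain LAMB and its offloaded wrapper. A minimal standalone sketch of that selection, assuming (as the diff itself does) that OffloadOptimizer accepts the parameter groups, an optim_cls, and forwards the remaining keyword arguments to that inner optimizer class:

    from hivemind.optim import LambWithGradientClipping, OffloadOptimizer

    def make_optimizer(optimizer_grouped_parameters, offload=False):
        # Hyperparameters copied verbatim from the diff above
        optim_kwargs = dict(
            lr=0.0035355339059327377,  # = 0.005 / sqrt(2)
            betas=(0.9, 0.999),
            eps=1e-6,
            weight_decay=0.01,
            max_grad_norm=1,
            clamp_value=10000.0,
            debias=True,
        )
        if offload:
            # Assumed behavior: keep optimizer statistics in host (CPU)
            # memory while delegating the update rule to the wrapped LAMB
            return OffloadOptimizer(
                optimizer_grouped_parameters,
                optim_cls=LambWithGradientClipping,
                **optim_kwargs,
            )
        return LambWithGradientClipping(optimizer_grouped_parameters, **optim_kwargs)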