
Add gradient clipping support to ExpertBackend (#214)

Max Ryabinin · 4 years ago
commit 6128cbbd51

+ 5 - 2
hivemind/hivemind_cli/run_server.py

@@ -40,8 +40,9 @@ def main():
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
     parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',
                         help='LR scheduler type to use')
-    parser.add_argument('--num-warmup-steps', type=int, required=False, help='the number of warmup steps for LR schedule')
-    parser.add_argument('--num-training-steps', type=int, required=False, help='the total number of steps for LR schedule')
+    parser.add_argument('--num_warmup_steps', type=int, required=False, help='The number of warmup steps for LR schedule')
+    parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
+    parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')
 
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
@@ -53,6 +54,8 @@ def main():
     parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression '
                         'parameter for grpc. Can be NONE, MEANSTD or FLOAT16')
     parser.add_argument('--checkpoint_dir', type=Path, required=False, help='Directory to store expert checkpoints')
+    parser.add_argument('--stats_report_interval', type=int, required=False,
+                        help='Interval between two reports of batch processing performance statistics')
 
     # fmt:on
     args = vars(parser.parse_args())
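
Read in isolation, the new flag parses into a plain float that is presumably forwarded to Server.create along with the rest of the args dict; a minimal standalone sketch (a separate toy parser, not the full one above):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--clip_grad_norm', type=float, required=False,
                        help='Maximum gradient norm used for clipping')
    args = vars(parser.parse_args(['--clip_grad_norm', '1.0']))
    assert args == {'clip_grad_norm': 1.0}  # ready to be passed as a keyword argument downstream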

+ 10 - 6
hivemind/server/__init__.py

@@ -70,9 +70,10 @@ class Server(threading.Thread):
     @classmethod
     def create(cls, listen_on='0.0.0.0:*', num_experts: int = None, expert_uids: str = None, expert_pattern: str = None,
                expert_cls='ffn', hidden_dim=1024, optim_cls=torch.optim.Adam, scheduler: str = 'none',
-               num_warmup_steps=None, num_training_steps=None, num_handlers=None, max_batch_size=4096, device=None,
-               no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
-               compression=CompressionType.NONE, *, start: bool, **kwargs) -> Server:
+               num_warmup_steps=None, num_total_steps=None, clip_grad_norm=None, num_handlers=None, max_batch_size=4096,
+               device=None, no_dht=False, initial_peers=(), dht_port=None, checkpoint_dir: Optional[Path] = None,
+               compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, *, start: bool,
+               **kwargs) -> Server:
         """
         Instantiate a server with several identical experts. See argparse comments below for details
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -89,7 +90,8 @@ class Server(threading.Thread):
         :param optim_cls: uses this optimizer to train all experts
         :param scheduler: if not `none`, the name of the expert LR scheduler
         :param num_warmup_steps: the number of warmup steps for LR schedule
-        :param num_training_steps: the total number of steps for LR schedule
+        :param num_total_steps: the total number of steps for LR schedule
+        :param clip_grad_norm: maximum gradient norm used for clipping
 
         :param no_dht: if specified, the server will not be attached to a dht
         :param initial_peers: a list of peers that will introduce this node to the dht,\
@@ -105,6 +107,7 @@ class Server(threading.Thread):
             for each BatchTensorProto in ExpertBackend for the respective experts.
 
         :param start: if True, starts server right away and returns when server is ready for requests
+        :param stats_report_interval: interval between two reports of batch processing performance statistics
         """
         if len(kwargs) != 0:
             logger.info("Ignored kwargs:", kwargs)
@@ -165,14 +168,15 @@ class Server(threading.Thread):
                                                          optimizer=optim_cls(expert.parameters()),
                                                          scheduler=scheduler,
                                                          num_warmup_steps=num_warmup_steps,
-                                                         num_training_steps=num_training_steps,
+                                                         num_total_steps=num_total_steps,
+                                                         clip_grad_norm=clip_grad_norm,
                                                          max_batch_size=max_batch_size)
 
         if checkpoint_dir is not None:
             load_experts(experts, checkpoint_dir)
 
         return cls(dht, experts, listen_on=listen_on, num_connection_handlers=num_handlers, device=device,
-                   checkpoint_dir=checkpoint_dir, start=start)
+                   checkpoint_dir=checkpoint_dir, stats_report_interval=stats_report_interval, start=start)
 
     def run(self):
         """

+ 9 - 4
hivemind/server/expert_backend.py

@@ -35,7 +35,8 @@ class ExpertBackend:
     :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
     :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
     :param num_warmup_steps: the number of warmup steps for LR schedule
-    :param num_training_steps: the total number of steps for LR schedule
+    :param num_total_steps: the total number of steps for LR schedule
+    :param clip_grad_norm: maximum gradient norm used for clipping
     :param kwargs: extra parameters to be forwarded into TaskPool.__init__
     """
 
@@ -44,7 +45,7 @@ class ExpertBackend:
                  args_schema: Tuple[BatchTensorDescriptor, ...] = None,
                  kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
                  outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
-                 num_warmup_steps: int = None, num_training_steps: int = None,
+                 num_warmup_steps: int = None, num_total_steps: int = None, clip_grad_norm: float = None,
                  **kwargs):
         super().__init__()
         self.expert, self.optimizer, self.name = expert, optimizer, name
@@ -52,8 +53,9 @@ class ExpertBackend:
         if scheduler is None:
             self.scheduler = None
         else:
-            assert optimizer is not None and num_warmup_steps is not None and num_training_steps is not None
-            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_training_steps)
+            assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None
+            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps)
+        self.clip_grad_norm = clip_grad_norm
 
         self.args_schema = args_schema = tuple(args_schema or ())
         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
@@ -147,6 +149,9 @@ class ExpertBackend:
         """
         Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
         """
+        if self.clip_grad_norm is not None:
+            torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)
+
         self.optimizer.step()
         self.optimizer.zero_grad()
 

+ 16 - 7
hivemind/server/runtime.py

@@ -7,7 +7,7 @@ from queue import SimpleQueue
 from selectors import DefaultSelector, EVENT_READ
 from statistics import mean
 from time import time
-from typing import Dict, NamedTuple
+from typing import Dict, NamedTuple, Optional
 
 import torch
 from prefetch_generator import BackgroundGenerator
@@ -43,7 +43,7 @@ class Runtime(threading.Thread):
     """
 
     def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
-                 device: torch.device = None, stats_report_interval=30):
+                 device: torch.device = None, stats_report_interval: Optional[int] = None):
         super().__init__()
         self.expert_backends = expert_backends
         self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values())))
@@ -51,7 +51,9 @@ class Runtime(threading.Thread):
         self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
 
-        self.stats_reporter = StatsReporter(stats_report_interval)
+        self.stats_report_interval = stats_report_interval
+        if self.stats_report_interval is not None:
+            self.stats_reporter = StatsReporter(self.stats_report_interval)
 
     def run(self):
         for pool in self.pools:
@@ -64,8 +66,10 @@ class Runtime(threading.Thread):
         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:
                 self.ready.set()
-                self.stats_reporter.start()
+                if self.stats_report_interval is not None:
+                    self.stats_reporter.start()
                 logger.info("Started")
+
                 for pool, batch_index, batch in BackgroundGenerator(
                         self.iterate_minibatches_from_pools(), self.prefetch_batches):
                     logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
@@ -76,13 +80,18 @@ class Runtime(threading.Thread):
 
                     batch_size = outputs[0].size(0)
                     logger.debug(f"Pool {pool.uid}: batch {batch_index} processed, size {batch_size}")
-                    self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)
+
+                    if self.stats_report_interval is not None:
+                        self.stats_reporter.report_stats(pool.uid, batch_size, batch_processing_time)
 
                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
             finally:
                 logger.info("Shutting down")
-                self.stats_reporter.stop.set()
-                self.stats_reporter.join()
+
+                if self.stats_report_interval is not None:
+                    self.stats_reporter.stop.set()
+                    self.stats_reporter.join()
+
                 self.shutdown()
 
     SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
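
With the default changed from 30 to None, performance reporting becomes opt-in: when no interval is given, the StatsReporter thread is never created, started, fed, or joined. An illustrative construction (the empty dict stands in for a real uid-to-ExpertBackend mapping):

    from hivemind.server.runtime import Runtime

    expert_backends = {}  # normally a mapping of expert uid -> ExpertBackend

    runtime = Runtime(expert_backends)                            # default: no StatsReporter thread at all
    runtime = Runtime(expert_backends, stats_report_interval=30)  # previous default, now opt-in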

+ 1 - 1
tests/test_expert_backend.py

@@ -25,7 +25,7 @@ def example_experts():
     expert_backend = ExpertBackend(name=EXPERT_NAME, expert=expert, optimizer=opt,
                                    scheduler=get_linear_schedule_with_warmup,
                                    num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE,
-                                   num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
+                                   num_total_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE,
                                    args_schema=args_schema, outputs_schema=BatchTensorDescriptor(1), max_batch_size=1,
                                    )
     experts = {EXPERT_NAME: expert_backend}