
Implement server-side averaging

Max Ryabinin · 4 years ago
commit 5a982ab0c5

+ 7 - 0
hivemind/hivemind_cli/run_server.py

@@ -38,6 +38,13 @@ def main():
                         help='Minimum required batch size for all expert operations')
     parser.add_argument('--max_batch_size', type=int, default=16384,
                         help='The total number of examples in the same batch will not exceed this value')
+    parser.add_argument('--use_averaging', action='store_true', help='Whether to use decentralized parameter and '
+                                                                     'gradient averaging by wrapping the optimizer '
+                                                                     'with CollaborativeOptimizer')
+    parser.add_argument('--averaging_target_batch_size', type=int, required=False,
+                        help='Number of examples to accumulate across all peers before averaging')
+    parser.add_argument('--averaging_target_group_size', type=int, required=False,
+                        help='Target group size for decentralized averaging')
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all experts will use this device in torch notation; default: cuda if available else cpu')
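
For reference, a minimal standalone sketch of how the three new options parse; it only mirrors the arguments added in this hunk, and the bare parser below is illustrative rather than the actual one from run_server.py:

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--use_averaging', action='store_true')
    parser.add_argument('--averaging_target_batch_size', type=int, required=False)
    parser.add_argument('--averaging_target_group_size', type=int, required=False)

    # the flags a user would pass to enable averaging (values are illustrative)
    args = parser.parse_args(
        '--use_averaging --averaging_target_batch_size 4096 --averaging_target_group_size 16'.split()
    )
    assert args.use_averaging and args.averaging_target_batch_size == 4096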
 

+ 31 - 3
hivemind/moe/server/__init__.py

@@ -26,6 +26,7 @@ from hivemind.moe.server.layers import (
     schedule_name_to_scheduler,
 )
 from hivemind.moe.server.runtime import Runtime
+from hivemind.optim import CollaborativeOptimizer
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
@@ -107,6 +108,9 @@ class Server(threading.Thread):
         num_handlers=None,
         min_batch_size=1,
         max_batch_size=4096,
+        use_averaging: bool = False,
+        averaging_target_batch_size: Optional[int] = None,
+        averaging_target_group_size: Optional[int] = None,
         device=None,
         no_dht=False,
         initial_peers=(),
@@ -122,13 +126,17 @@ class Server(threading.Thread):
         :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
         :param num_experts: run this many identical experts
         :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-           means "sample random experts between myprefix.0.0 and myprefix.255.255;
+            means "sample random experts between myprefix.0.0 and myprefix.255.255;
         :param expert_uids: spawn experts with these exact uids, overrides num_experts and expert_pattern
         :param expert_cls: expert type from hivemind.moe.server.layers, e.g. 'ffn' or 'transformer';
         :param hidden_dim: main dimension for expert_cls
         :param num_handlers: server will use this many parallel processes to handle incoming requests
         :param min_batch_size: total num examples in the same batch will be greater than this value
         :param max_batch_size: total num examples in the same batch will not exceed this value
+        :param use_averaging: Whether to use decentralized parameter and gradient averaging by wrapping the optimizer
+            with CollaborativeOptimizer
+        :param averaging_target_batch_size: number of examples to accumulate across all peers before averaging
+        :param averaging_target_group_size: target group size for decentralized averaging
         :param device: all experts will use this device in torch notation; default: cuda if available else cpu
 
         :param optim_cls: uses this optimizer to train all experts
@@ -184,7 +192,11 @@ class Server(threading.Thread):
             uids_to_generate = num_experts - len(expert_uids)
             if uids_to_generate > 0:
                 logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
-                expert_uids.extend(generate_uids_from_pattern(uids_to_generate, expert_pattern, dht))
+                expert_uids.extend(
+                    generate_uids_from_pattern(
+                        uids_to_generate, expert_pattern, dht, remove_duplicates=not use_averaging
+                    )
+                )
 
         num_experts = len(expert_uids)
         num_handlers = num_handlers if num_handlers is not None else num_experts * 8
@@ -203,11 +215,27 @@ class Server(threading.Thread):
         experts = {}
         for expert_uid in expert_uids:
             expert = name_to_block[expert_cls](hidden_dim)
+
+            optim = optim_cls(expert.parameters())
+            if use_averaging:
+                assert averaging_target_batch_size is not None
+                assert averaging_target_group_size is not None
+                optim = CollaborativeOptimizer(
+                    optim,
+                    dht=dht,
+                    prefix=expert_uid.replace(".", ""),
+                    compression_type=compression,
+                    target_batch_size=averaging_target_batch_size,
+                    target_group_size=averaging_target_group_size,
+                    reuse_grad_buffers=True,
+                    start=True,
+                )
+
             experts[expert_uid] = hivemind.ExpertBackend(
                 name=expert_uid,
                 expert=expert,
                 args_schema=args_schema,
-                optimizer=optim_cls(expert.parameters()),
+                optimizer=optim,
                 scheduler=scheduler,
                 num_warmup_steps=num_warmup_steps,
                 num_total_steps=num_total_steps,
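
Outside the Server, the same wrapping pattern can be reproduced in isolation. The snippet below is a rough sketch mirroring the call above, not an excerpt from this commit: the standalone DHT node, the Linear layer standing in for an expert, and the concrete batch/group sizes are illustrative assumptions (compression is left at its default).

    import torch
    import hivemind
    from hivemind.optim import CollaborativeOptimizer

    dht = hivemind.DHT(start=True)                 # assumed: a fresh standalone DHT node
    expert = torch.nn.Linear(1024, 1024)           # stand-in for an expert module
    optim = torch.optim.Adam(expert.parameters())

    # mirrors the wrapping added to the Server above
    optim = CollaborativeOptimizer(
        optim,
        dht=dht,
        prefix="ffn00",                            # expert uid with "." delimiters stripped
        target_batch_size=4096,
        target_group_size=16,
        reuse_grad_buffers=True,
        start=True,
    )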

+ 9 - 4
hivemind/moe/server/expert_backend.py

@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Sequence, Tuple, Union
 import torch
 from torch import nn
 
+import hivemind
 from hivemind.moe.server.task_pool import TaskPool
 from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
@@ -176,12 +177,16 @@ class ExpertBackend:
         if self.clip_grad_norm is not None:
             torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)
 
-        self.optimizer.step()
-        self.optimizer.zero_grad()
+        if isinstance(self.optimizer, hivemind.CollaborativeOptimizer):
+            self.optimizer.step(batch_size)
+        else:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
 
-        if self.scheduler is not None:
-            self.scheduler.step()
+            if self.scheduler is not None:
+                self.scheduler.step()
 
+        # TODO: update_count may not reflect actual parameter updates when CollaborativeOptimizer is used
         self.update_count += 1
         self.examples_processed += batch_size
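
For context, the collaborative branch changes the calling convention: step() now receives the number of processed samples, and the actual parameter update (plus gradient zeroing, since the Server passes reuse_grad_buffers=True) is deferred until the whole collaboration accumulates target_batch_size examples, which is why zero_grad() and the local scheduler are skipped on that path. A rough sketch of one training step under this convention, reusing `expert` and `optim` from the sketch above with an illustrative input batch:

    batch = torch.randn(32, 1024)        # illustrative batch of 32 examples

    loss = expert(batch).pow(2).mean()   # toy objective
    loss.backward()

    # gradients are accumulated locally; an (averaged) update is applied only
    # once the collaboration-wide target_batch_size is reached
    optim.step(batch_size=len(batch))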
 

+ 14 - 7
hivemind/moe/server/expert_uid.py

@@ -14,6 +14,8 @@ UID_DELIMITER = "."  # when declaring experts, DHT store all prefixes of that ex
 FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
 UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims
 PREFIX_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$")  # e.g. expert. or ffn.45. (ends with ".")
+
+
 #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
 
 
@@ -35,17 +37,23 @@ def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPref
 
 
 def generate_uids_from_pattern(
-    num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
+    num_experts: int,
+    expert_pattern: Optional[str],
+    dht: Optional[DHT] = None,
+    attempts_per_expert=10,
+    remove_duplicates=True,
 ) -> List[str]:
     """
-    Sample experts from a given pattern, remove duplicates.
+    Sample experts from a given pattern, optionally remove duplicates.
     :param num_experts: sample this many unique expert uids
     :param expert_pattern: a string pattern or a list of expert uids,  example: myprefix.[0:32].[0:256]\
-     means "sample random experts between myprefix.0.0 and myprefix.255.255;
+        means "sample random experts between myprefix.0.0 and myprefix.255.255"
     :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
+    :param remove_duplicates: whether to exclude expert uids that are already present in the DHT
+        (you may disable it if you want to have the same expert on multiple peers)
     :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
     :note: this method is not strictly process-safe. If several servers run it concurrently, they have
-     a small chance of sampling duplicate expert uids.
+        a small chance of sampling duplicate expert uids.
     """
     remaining_attempts = attempts_per_expert * num_experts
     found_uids, attempted_uids = list(), set()
@@ -72,7 +80,7 @@ def generate_uids_from_pattern(
 
     while remaining_attempts > 0 and len(found_uids) < num_experts:
 
-        # 1. sample new expert uids at random
+        # sample new expert uids at random
         new_uids = []
         while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
             new_uid = _generate_uid()
@@ -81,8 +89,7 @@ def generate_uids_from_pattern(
                 attempted_uids.add(new_uid)
                 new_uids.append(new_uid)
 
-        # 2. look into DHT (if given) and remove duplicates
-        if dht:
+        if dht and remove_duplicates:
             existing_expert_uids = {
                 found_expert.uid
                 for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
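
The new remove_duplicates flag matters for averaging: replicas of the same expert must be allowed to coexist on several peers so that their parameters can be averaged. A small sketch of calling the function with duplicates kept; the pattern and the pre-existing `dht` handle are illustrative assumptions:

    from hivemind.moe.server.expert_uid import generate_uids_from_pattern

    # keep uids even if another peer already registered them in the DHT,
    # so that multiple peers can host (and average) the same experts
    uids = generate_uids_from_pattern(
        num_experts=4,
        expert_pattern="ffn.[0:4]",
        dht=dht,                         # assumed: an already-running hivemind.DHT
        remove_duplicates=False,
    )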

+ 4 - 1
hivemind/optim/collaborative.py

@@ -400,7 +400,10 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         current_time = get_dht_time()
 
         if not isinstance(response, dict) or len(response) == 0:
-            logger.log(self.status_loglevel, f"Found no active peers: {response}")
+            logger.log(
+                self.status_loglevel,
+                f"Collaboration {self.prefix} found no active peers {f': {response}' if response else ''}",
+            )
             local_eta_next_step = (
                 max(0, self.target_batch_size - self.local_steps_accumulated) / self.performance_ema.samples_per_second
             )