
Improve user-friendliness and fix misc errors (#428)

This PR introduces two quality-of-life changes and three bug fixes, all found with the help of community members.

1. Warn if there are too many no-grad params (Found by mr_seeker@ from [KoboldAI](https://github.com/KoboldAI))

The problem manifested itself when he created an optimizer over GPT-J-6B with most parameters set to `requires_grad_(False)`.

Currently, TrainingStateAverager still averages these parameters with peers in order to properly average batchnorm EMAs and similar statistics. As a result, it would run out of memory (OOM) while trying to initialize averaging buffers for all GPT-J parameters.

This PR checks if the majority of parameters are non-trainable and prints a warning with instructions.
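
As the warning suggests, the extra communication can be avoided by feeding only trainable parameters to the optimizer. A minimal sketch with a generic PyTorch model (the model and hyperparameters are placeholders):

```python
import torch

# Hypothetical setup: freeze most of the model, fine-tune only the head.
model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 10))
for param in model[0].parameters():
    param.requires_grad_(False)  # frozen backbone

# Feed only trainable parameters to the optimizer so that
# TrainingStateAverager does not allocate averaging buffers for frozen weights.
trainable_params = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable_params, lr=1e-3)
```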

2. Do not count metric loggers as training peers (found with the help of CALM collaborators)

Currently, non-GPU peers that are responsible for metric logging periodically report having accumulated 0 samples at a near-zero (epsilon) samples/sec rate. This causes confusion because these peers do not actually compute minibatches and are NOT counted in "averaged gradients with X peers".
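
Schematically, the new reporting rule boils down to the condition below (an illustrative standalone sketch; the class and method names are hypothetical, only the field names mirror the real ProgressTracker code):

```python
class ReportGate:
    """Illustrative sketch of the reporting rule added to ProgressTracker._progress_reporter."""

    def __init__(self) -> None:
        self.last_report_epoch = -float("inf")

    def should_report(self, samples_accumulated: int, global_epoch: int) -> bool:
        if samples_accumulated > 0:
            self.last_report_epoch = global_epoch  # this peer is actively training
        # peers that never accumulate samples (e.g. metric loggers) stop publishing progress,
        # while peers that reported during the current or previous epoch keep doing so
        return self.last_report_epoch >= global_epoch - 1
```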

3. Found in @bawr's logs
Optimizer with DPU (delayed parameter updates) and HivemindGradScaler would trigger an error if a gradient overflow was detected during the delayed update.
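
For context, the affected configuration combines hivemind.Optimizer in DPU mode (with the optimizer offloaded) and the hivemind gradient scaler, roughly as in the hedged sketch below; the model, hyperparameters, and DHT setup are placeholders, and argument names should be checked against the hivemind docs:

```python
import torch
import hivemind

# Placeholders: a real run needs initial_peers for the swarm and a GPU for mixed precision.
dht = hivemind.DHT(start=True)
model = torch.nn.Linear(512, 10).cuda()

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",
    target_batch_size=4096,
    batch_size_per_step=32,
    optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    params=model.parameters(),
    offload_optimizer=True,     # keep optimizer state on CPU ("full DPU")
    delay_optimizer_step=True,  # apply the optimizer step in background, after averaging
    delay_grad_averaging=True,
)
scaler = hivemind.GradScaler()  # hivemind-aware replacement for torch.cuda.amp.GradScaler

# One schematic mixed-precision step:
x = torch.randn(32, 512, device="cuda")
y = torch.randint(0, 10, (32,), device="cuda")
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(opt)  # the overflow-during-DPU path addressed by this PR is hit around here
scaler.step(opt)
scaler.update()
```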

4. Found by @pr-mais
Setting ANNOUNCE_MADDRS with port 0 would silently break matchmaking with some probability. The reason is that P2P does not auto-resolve port 0 in announce_maddrs, so the announced address is unusable by other peers. We now raise an error if the user sets up P2P with this configuration.
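
As a usage note, announce an explicit port that matches the one you actually listen on; a minimal sketch (the public IP and port below are placeholders):

```python
import hivemind

# After this PR, announcing /tcp/0 (or /udp/0) raises a ValueError
# instead of silently breaking matchmaking.
dht = hivemind.DHT(
    host_maddrs=["/ip4/0.0.0.0/tcp/31337"],
    announce_maddrs=["/ip4/203.0.113.5/tcp/31337"],  # explicit public port, not /tcp/0
    start=True,
)
```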

5. Found by @borzunov
Reducer error when handling a network error. This is a non-fatal error that occurs on the reducer when some senders take too long to send their data and are marked as failed (banned) -- but their last message still makes it through.



Co-authored-by: Mais Alheraki <mais@fairybits.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Julius ter Pelkwijk <1099127+mrseeker@users.noreply.github.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 3 years ago
parent commit 595b831bca

+ 9 - 5
hivemind/averaging/allreduce.py

@@ -4,7 +4,7 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Typ
 
 import torch
 
-from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
+from hivemind.averaging.partition import AllreduceException, BannedException, TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
@@ -343,10 +343,14 @@ class AllReduceRunner(ServicerBase):
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
             ):
-                averaged_part = await self.tensor_part_reducer.accumulate_part(
-                    sender_index, part_index, tensor_part, weight=weight
-                )
-                part_index += 1
+                try:
+                    averaged_part = await self.tensor_part_reducer.accumulate_part(
+                        sender_index, part_index, tensor_part, weight=weight
+                    )
+                    part_index += 1
+                except BannedException:
+                    logger.debug(f"Sender {sender_index} is already banned")
+                    break  # sender was banned, we no longer need to aggregate it
 
                 serialized_delta = await loop.run_in_executor(
                     None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)

+ 2 - 4
hivemind/averaging/matchmaking.py

@@ -9,8 +9,6 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
-import numpy as np
-
 from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
@@ -203,7 +201,7 @@ class Matchmaking:
                 message = await asyncio.wait_for(anext(stream), timeout=self.request_timeout)
 
                 if message.code == averaging_pb2.ACCEPTED:
-                    logger.debug(f"{self.peer_id} - joining the group of {leader}; waiting for peers")
+                    logger.debug(f"{self.peer_id} - joining the group of {leader}, waiting for peers")
                     self.current_leader = leader
                     self.was_accepted_to_group.set()
                     if len(self.current_followers) > 0:
@@ -242,7 +240,7 @@ class Matchmaking:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
         except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
-            logger.debug(f"{self} - failed to request potential leader {leader}:")
+            logger.debug(f"{self} - failed to request potential leader {leader}:", exc_info=True)
             return None
 
         finally:

+ 9 - 0
hivemind/averaging/partition.py

@@ -227,6 +227,9 @@ class TensorPartReducer:
             await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
             if self.finished.is_set():
                 raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")
+
+        if self.sender_failed_after[sender_index] != float("inf"):
+            raise BannedException(f"sender {sender_index} was banned in background")
         assert part_index == self.current_part_index
 
         current_part_future = self.current_part_future
@@ -241,6 +244,8 @@ class TensorPartReducer:
     def on_sender_failed(self, sender_index: int):
         """Exclude that sender's data for averaging any parts that it did not submit yet."""
         self.sender_failed_after[sender_index] = self.num_parts_received[sender_index]
+        if self.finished.is_set():
+            return
         if self.current_part_index == self.num_parts_received[sender_index]:
             self.num_current_senders -= 1
             self.check_current_part_finished()
@@ -270,3 +275,7 @@ class TensorPartReducer:
 
 class AllreduceException(Exception):
     """A special exception that is raised when allreduce can't continue normally (e.g. disconnected/protocol error)"""
+
+
+class BannedException(AllreduceException):
+    """An exception that indicates that a given sender was banned and will no longer be aggregated"""

+ 16 - 5
hivemind/optim/grad_scaler.py

@@ -35,6 +35,7 @@ class GradScaler(TorchGradScaler):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
         self._is_ready_to_update = False
+        self._inner_optimizer_states = {}
         self._optimizer_states_to_reset = set()
         self._lock = threading.RLock()
 
@@ -52,7 +53,12 @@ class GradScaler(TorchGradScaler):
             assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
             if self._is_running_global_step:
                 super().unscale_(optimizer)
-                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                # note: we store unscaled optimizer state in a separate dict and not in _per_optimizer_states in order
+                # to avoid an edge case where full DPU peer encounters overflow in local gradients while averaging
+                # offloaded gradients (i.e. after global unscale but before global step). Due to overflow, next call to
+                # .update on user side would reset *all* optimizer states and cause .step to unscale gradients twice.
+                # Offloaded optimizer is not affected by overflow in on-device gradients and should not be reset.
                 return True
             else:
                 self._check_inf_per_device(optimizer)
@@ -62,14 +68,19 @@ class GradScaler(TorchGradScaler):
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
             # ^-- invoked privately within hivemind optimizer
+            inner_optimizer = optimizer
             with self._lock:
                 if self._is_ready_to_update:
                     logger.warning("Please call grad_scaler.update() after each step")
+
+                inner_optimizer_state = self._inner_optimizer_states.pop(id(inner_optimizer), None)
+                if inner_optimizer_state is not None:
+                    self._per_optimizer_states[id(inner_optimizer)] = inner_optimizer_state
                 assert (
-                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
-                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
-                if self.are_grads_finite(optimizer, use_cached=True):
-                    super().step(optimizer, *args, **kwargs)
+                    self._per_optimizer_states[id(inner_optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step"
+                if self.are_grads_finite(inner_optimizer, use_cached=True):
+                    super().step(inner_optimizer, *args, **kwargs)
                 else:
                     logger.warning("Skipping global step due to gradient over/underflow")
                 self._is_ready_to_update = True

+ 17 - 12
hivemind/optim/progress_tracker.py

@@ -195,6 +195,7 @@ class ProgressTracker(threading.Thread):
     async def _progress_reporter(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
         last_report_time = -float("inf")
+        last_report_epoch = -float("inf")
         store_task = None
         try:
             while not self.shutdown_triggered.is_set():
@@ -209,19 +210,23 @@ class ProgressTracker(threading.Thread):
 
                 local_progress = self.local_progress
                 last_report_time = get_dht_time()
-
-                store_task = asyncio.create_task(
-                    asyncio.wait_for(
-                        self.dht.store(
-                            key=self.training_progress_key,
-                            subkey=self._local_public_key,
-                            value=local_progress.dict(),
-                            expiration_time=last_report_time + self.metadata_expiration,
-                            return_future=True,
-                        ),
-                        timeout=self.metadata_expiration,
+                if local_progress.samples_accumulated > 0:
+                    last_report_epoch = self.global_epoch
+
+                if last_report_epoch >= self.global_epoch - 1:
+                    # report progress if peer is synchronized and actively reporting samples. Do not report aux peers.
+                    store_task = asyncio.create_task(
+                        asyncio.wait_for(
+                            self.dht.store(
+                                key=self.training_progress_key,
+                                subkey=self._local_public_key,
+                                value=local_progress.dict(),
+                                expiration_time=last_report_time + self.metadata_expiration,
+                                return_future=True,
+                            ),
+                            timeout=self.metadata_expiration,
+                        )
                     )
-                )
         finally:
             logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}")
             if store_task is not None:

+ 9 - 0
hivemind/optim/state_averager.py

@@ -152,6 +152,15 @@ class TrainingStateAverager(DecentralizedAverager):
         parameter_names = tuple(nested_flatten(parameter_names))
         assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
+        params_with_grad = sum(p.numel() for p in parameters if p.requires_grad)
+        params_no_grad = sum(p.numel() for p in parameters if not p.requires_grad)
+        if params_no_grad >= params_with_grad:
+            logger.warning(
+                "The majority of parameters have requires_grad=False, but they are still synchronized"
+                " with peers. If these parameters are frozen (not updated), please do not feed them into "
+                "the optimizer at all in order to avoid communication overhead. Proceeding anyway."
+            )
+
         return param_groups, parameters, parameter_names
 
     def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):

+ 5 - 0
hivemind/p2p/p2p_daemon.py

@@ -140,6 +140,11 @@ class P2P:
         socket_uid = secrets.token_urlsafe(8)
         self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pd-{socket_uid}.sock")
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
+        if announce_maddrs is not None:
+            for addr in announce_maddrs:
+                addr = Multiaddr(addr)
+                if ("tcp" in addr and addr["tcp"] == "0") or ("udp" in addr and addr["udp"] == "0"):
+                    raise ValueError("Please specify an explicit port in announce_maddrs: port 0 is not supported")
 
         need_bootstrap = bool(initial_peers) or use_ipfs
         process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})