Merge branch 'master' into rfc_optimizer

justheuristic 3 years ago
parent commit abbea26c9c

BIN
docs/_static/dht.odp


BIN
docs/_static/dht.png


+ 2 - 1
hivemind/__init__.py

@@ -1,4 +1,4 @@
-from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
@@ -16,6 +16,7 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
+    TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
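
For downstream code, the practical effect of this hunk is that TrainingAverager now lives under hivemind.optim while staying importable from the top-level package. A minimal sketch of the updated import path (assuming the re-export above stays in place):

    # old location (removed in this change):
    #   from hivemind.averaging import TrainingAverager
    # new location, plus the preserved top-level alias:
    from hivemind.optim import TrainingAverager
    import hivemind
    assert hivemind.TrainingAverager is TrainingAverager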

+ 0 - 1
hivemind/averaging/__init__.py

@@ -1,2 +1 @@
 from hivemind.averaging.averager import DecentralizedAverager
-from hivemind.averaging.training import TrainingAverager

+ 10 - 11
hivemind/averaging/averager.py

@@ -32,7 +32,15 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
+from hivemind.utils.asyncio import (
+    achain,
+    aiter_with_timeout,
+    anext,
+    as_aiter,
+    azip,
+    enter_asynchronously,
+    switch_to_uvloop,
+)
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -453,7 +461,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )
 
-            async with self.get_tensors_async() as local_tensors:
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -505,15 +513,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
 
-    @contextlib.asynccontextmanager
-    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
-        """Like get_tensors, but uses an asynchronous contextmanager"""
-        try:
-            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
-            yield self._averaged_tensors
-        finally:
-            self.lock_averaged_tensors.release()
-
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
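
For context, the deleted get_tensors_async helper is superseded by wrapping the synchronous get_tensors() context manager with enter_asynchronously from hivemind.utils.asyncio. A minimal sketch of what such a helper is assumed to do (the actual hivemind implementation may differ in details):

    import asyncio
    import contextlib

    @contextlib.asynccontextmanager
    async def enter_asynchronously(context_manager):
        """Enter a blocking context manager off the event loop so it stays responsive."""
        loop = asyncio.get_event_loop()
        # __enter__ may block (e.g. acquiring the shared tensor lock), so run it in the default executor
        entered = await loop.run_in_executor(None, context_manager.__enter__)
        try:
            yield entered
        finally:
            context_manager.__exit__(None, None, None)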

+ 1 - 1
hivemind/averaging/control.py

@@ -103,7 +103,7 @@ class StepControl(MPFuture):
     @stage.setter
     def stage(self, stage: AveragingStage):
         if stage == AveragingStage.RUNNING_ALLREDUCE:
-            self.can_modify = False
+            self.began_allreduce = True
         self._shared_buffer[StepControl._STAGE] = stage.value
 
     @property
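
The renamed flag ties into the StepControl lifecycle used further down in this diff. A rough usage sketch based on the calls visible in grad_averager.py, where averager is assumed to be a DecentralizedAverager and a readable began_allreduce attribute is an assumption:

    # schedule a group in advance (wait=False, require_trigger=True, as in schedule_step below)
    control = averager.step(scheduled_time=hivemind.get_dht_time() + 30, wait=False, require_trigger=True)
    # ... do other work while matchmaking proceeds in the background ...
    control.weight = 1.0            # optionally override the averaging weight before triggering
    control.allow_allreduce()       # release the trigger; stage eventually becomes RUNNING_ALLREDUCE
    gathered = control.result()     # block until the averaging round finishes (or fails)
    assert control.began_allreduce  # set by the stage setter shown above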

+ 3 - 322
hivemind/dht/__init__.py

@@ -4,7 +4,7 @@ Hivemind DHT is based on Kademlia [1] with added support for improved bulk store
 
 The code is organized as follows:
 
- * **class DHT (__init__.py)** - high-level class for model training. Runs DHTNode in a background process.
+ * **class DHT (dht.py)** - high-level class for model training. Runs DHTNode in a background process.
  * **class DHTNode (node.py)** - an asyncio implementation of dht server, stores AND gets keys.
  * **class DHTProtocol (protocol.py)** - an RPC protocol to request data from dht nodes.
  * **async def traverse_dht (traverse.py)** - a search algorithm that crawls DHT peers.
@@ -12,327 +12,8 @@ The code is organized as follows:
 - [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric.
 - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
 """
-from __future__ import annotations
-
-import asyncio
-import multiprocessing as mp
-import os
-from functools import partial
-from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
-
-from multiaddr import Multiaddr
 
+from hivemind.dht.dht import DHT
 from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
-from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
-from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_logger, switch_to_uvloop
-
-logger = get_logger(__name__)
-
-ReturnType = TypeVar("ReturnType")
-
-
-class DHT(mp.Process):
-    """
-    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
-    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
-    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
-
-    :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
-    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
-    :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
-      (but no more than one per key)
-    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
-    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
-      The validators will be combined using the CompositeValidator class. It merges them when possible
-      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
-    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
-    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
-    :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
-    """
-
-    _node: DHTNode
-
-    def __init__(
-        self,
-        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        *,
-        start: bool,
-        p2p: Optional[P2P] = None,
-        daemon: bool = True,
-        num_workers: int = DEFAULT_NUM_WORKERS,
-        record_validators: Iterable[RecordValidatorBase] = (),
-        shutdown_timeout: float = 3,
-        await_ready: bool = True,
-        **kwargs,
-    ):
-        self._parent_pid = os.getpid()
-        super().__init__()
-
-        if not (
-            initial_peers is None
-            or (
-                isinstance(initial_peers, Sequence)
-                and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
-            )
-        ):
-            raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
-        self.initial_peers = initial_peers
-        self.kwargs = kwargs
-        self.num_workers = num_workers
-
-        self._record_validator = CompositeValidator(record_validators)
-        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
-        self.shutdown_timeout = shutdown_timeout
-        self._ready = MPFuture()
-        self.daemon = daemon
-
-        # These values will be fetched from the child process when requested
-        self._peer_id = None
-        self._client_mode = None
-        self._p2p_replica = None
-
-        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
-
-        if start:
-            self.run_in_background(await_ready=await_ready)
-
-    def run(self) -> None:
-        """Serve DHT forever. This function will not return until DHT node is shut down"""
-
-        loop = switch_to_uvloop()
-        pipe_semaphore = asyncio.Semaphore(value=0)
-        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
-
-        async def _run():
-            try:
-                if self._daemon_listen_maddr is not None:
-                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
-                else:
-                    replicated_p2p = None
-
-                self._node = await DHTNode.create(
-                    initial_peers=self.initial_peers,
-                    num_workers=self.num_workers,
-                    record_validator=self._record_validator,
-                    p2p=replicated_p2p,
-                    **self.kwargs,
-                )
-            except Exception as e:
-                # Loglevel is DEBUG since normally the exception is propagated to the caller
-                logger.debug(e, exc_info=True)
-                self._ready.set_exception(e)
-                return
-            self._ready.set_result(None)
-
-            while True:
-                try:
-                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
-                except asyncio.TimeoutError:
-                    pass
-                if not self._inner_pipe.poll():
-                    continue
-                try:
-                    method, args, kwargs = self._inner_pipe.recv()
-                except (OSError, ConnectionError, RuntimeError) as e:
-                    logger.exception(e)
-                    await asyncio.sleep(self._node.protocol.wait_timeout)
-                    continue
-                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                if method == "_shutdown":
-                    await task
-                    break
-
-        loop.run_until_complete(_run())
-
-    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
-        """
-        Starts DHT in a background process. if await_ready, this method will wait until background dht
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready:
-            self.wait_until_ready(timeout)
-
-    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
-        self._ready.result(timeout=timeout)
-
-    def shutdown(self) -> None:
-        """Shut down a running dht process"""
-        if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [], {}))
-            self.join(self.shutdown_timeout)
-            if self.is_alive():
-                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
-                self.terminate()
-
-    async def _shutdown(self):
-        await self._node.shutdown()
-
-    def get(
-        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
-    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
-        """
-        Search for a key across DHT and return either first or latest entry (if found).
-        :param key: same key as in node.store(...)
-        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
-        :returns: (value, expiration time); if value was not found, returns None
-        """
-        future = MPFuture()
-        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
-        return future if return_future else future.result()
-
-    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
-        try:
-            result = await self._node.get(key, latest=latest, **kwargs)
-            if not future.done():
-                future.set_result(result)
-        except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
-            raise
-
-    def store(
-        self,
-        key: DHTKey,
-        value: DHTValue,
-        expiration_time: DHTExpiration,
-        subkey: Optional[Subkey] = None,
-        return_future: bool = False,
-        **kwargs,
-    ) -> Union[bool, MPFuture]:
-        """
-        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
-
-        :param key: msgpack-serializable key to be associated with value until expiration.
-        :param value: msgpack-serializable value to be stored under a given key until expiration.
-        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
-        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: True if store succeeds, False if it fails (due to no response or newer value)
-        """
-        future = MPFuture()
-        self._outer_pipe.send(
-            (
-                "_store",
-                [],
-                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
-            )
-        )
-        return future if return_future else future.result()
-
-    async def _store(
-        self,
-        key: DHTKey,
-        value: DHTValue,
-        expiration_time: DHTExpiration,
-        subkey: Optional[Subkey],
-        future: MPFuture,
-        **kwargs,
-    ):
-        try:
-            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
-            if not future.done():
-                future.set_result(result)
-        except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
-            raise
-
-    def run_coroutine(
-        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
-    ) -> Union[ReturnType, MPFuture[ReturnType]]:
-        """
-        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
-         for running custom functions DHT for special cases (e.g. declare experts, beam search)
-
-        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: coroutine outputs or MPFuture for these outputs
-        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
-          DHT fields made by this coroutine will not be accessible from the host process.
-        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
-          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
-        :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
-        """
-        future = MPFuture()
-        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
-        return future if return_future else future.result()
-
-    async def _run_coroutine(
-        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
-    ):
-        try:
-            future.set_result(await coro(self, self._node))
-        except BaseException as e:
-            logger.exception("Caught an exception when running a coroutine:")
-            future.set_exception(e)
-
-    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
-        if not self._ready.done():
-            raise RuntimeError(
-                "Can't append new validators before the DHT process has started. "
-                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
-            )
-
-        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
-
-    @staticmethod
-    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
-        node.protocol.record_validator.extend(record_validators)
-
-    @property
-    def peer_id(self) -> PeerID:
-        if self._peer_id is None:
-            self._peer_id = self.run_coroutine(DHT._get_peer_id)
-        return self._peer_id
-
-    @staticmethod
-    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
-        return node.peer_id
-
-    @property
-    def client_mode(self) -> bool:
-        if self._client_mode is None:
-            self._client_mode = self.run_coroutine(DHT._get_client_mode)
-        return self._client_mode
-
-    @staticmethod
-    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
-        return node.protocol.client_mode
-
-    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
-        """
-        Get multiaddrs of the current DHT node that should be accessible by other peers.
-
-        :param latest: ask the P2P daemon to refresh the visible multiaddrs
-        """
-
-        return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
-
-    @staticmethod
-    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
-        return await node.get_visible_maddrs(latest=latest)
-
-    async def replicate_p2p(self) -> P2P:
-        """
-        Get a replica of a P2P instance used in the DHT process internally.
-        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
-        """
-
-        if self._p2p_replica is None:
-            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
-            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
-        return self._p2p_replica
-
-    @staticmethod
-    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
-        return node.p2p.daemon_listen_maddr
-
-    def __del__(self):
-        if self._parent_pid == os.getpid() and self.is_alive():
-            self.shutdown()

+ 324 - 0
hivemind/dht/dht.py

@@ -0,0 +1,324 @@
+from __future__ import annotations
+
+import asyncio
+import multiprocessing as mp
+import os
+from functools import partial
+from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
+
+from multiaddr import Multiaddr
+
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
+from hivemind.dht.routing import DHTKey, DHTValue, Subkey
+from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P, PeerID
+from hivemind.utils import MPFuture, get_logger, switch_to_uvloop
+from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration
+
+logger = get_logger(__name__)
+ReturnType = TypeVar("ReturnType")
+
+
+class DHT(mp.Process):
+    """
+    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
+    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
+    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
+
+    :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
+    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
+    :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
+    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
+      (but no more than one per key)
+    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
+      The validators will be combined using the CompositeValidator class. It merges them when possible
+      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
+    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
+    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
+    :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
+    """
+
+    _node: DHTNode
+
+    def __init__(
+        self,
+        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+        *,
+        start: bool,
+        p2p: Optional[P2P] = None,
+        daemon: bool = True,
+        num_workers: int = DEFAULT_NUM_WORKERS,
+        record_validators: Iterable[RecordValidatorBase] = (),
+        shutdown_timeout: float = 3,
+        await_ready: bool = True,
+        **kwargs,
+    ):
+        self._parent_pid = os.getpid()
+        super().__init__()
+
+        if not (
+            initial_peers is None
+            or (
+                isinstance(initial_peers, Sequence)
+                and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
+            )
+        ):
+            raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
+        self.initial_peers = initial_peers
+        self.kwargs = kwargs
+        self.num_workers = num_workers
+
+        self._record_validator = CompositeValidator(record_validators)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self.shutdown_timeout = shutdown_timeout
+        self._ready = MPFuture()
+        self.daemon = daemon
+
+        # These values will be fetched from the child process when requested
+        self._peer_id = None
+        self._client_mode = None
+        self._p2p_replica = None
+
+        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
+
+        if start:
+            self.run_in_background(await_ready=await_ready)
+
+    def run(self) -> None:
+        """Serve DHT forever. This function will not return until DHT node is shut down"""
+
+        loop = switch_to_uvloop()
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                if self._daemon_listen_maddr is not None:
+                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                else:
+                    replicated_p2p = None
+
+                self._node = await DHTNode.create(
+                    initial_peers=self.initial_peers,
+                    num_workers=self.num_workers,
+                    record_validator=self._record_validator,
+                    p2p=replicated_p2p,
+                    **self.kwargs,
+                )
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
+                try:
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
+                try:
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._node.protocol.wait_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
+
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
+        """
+        Starts DHT in a background process. if await_ready, this method will wait until background dht
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)
+
+    def shutdown(self) -> None:
+        """Shut down a running dht process"""
+        if self.is_alive():
+            self._outer_pipe.send(("_shutdown", [], {}))
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
+                self.terminate()
+
+    async def _shutdown(self):
+        await self._node.shutdown()
+
+    def get(
+        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
+    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
+        """
+        Search for a key across DHT and return either first or latest entry (if found).
+        :param key: same key as in node.store(...)
+        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
+        :returns: (value, expiration time); if value was not found, returns None
+        """
+        future = MPFuture()
+        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
+        try:
+            result = await self._node.get(key, latest=latest, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey] = None,
+        return_future: bool = False,
+        **kwargs,
+    ) -> Union[bool, MPFuture]:
+        """
+        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
+
+        :param key: msgpack-serializable key to be associated with value until expiration.
+        :param value: msgpack-serializable value to be stored under a given key until expiration.
+        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
+        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: True if store succeeds, False if it fails (due to no response or newer value)
+        """
+        future = MPFuture()
+        self._outer_pipe.send(
+            (
+                "_store",
+                [],
+                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
+            )
+        )
+        return future if return_future else future.result()
+
+    async def _store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey],
+        future: MPFuture,
+        **kwargs,
+    ):
+        try:
+            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
+    ) -> Union[ReturnType, MPFuture[ReturnType]]:
+        """
+        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
+         for running custom functions on the DHT for special cases (e.g. declare experts, beam search)
+
+        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: coroutine outputs or MPFuture for these outputs
+        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
+          DHT fields made by this coroutine will not be accessible from the host process.
+        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
+          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
+        :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
+        """
+        future = MPFuture()
+        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
+        return future if return_future else future.result()
+
+    async def _run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
+    ):
+        try:
+            future.set_result(await coro(self, self._node))
+        except BaseException as e:
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)
+
+    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
+        if not self._ready.done():
+            raise RuntimeError(
+                "Can't append new validators before the DHT process has started. "
+                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
+            )
+
+        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
+
+    @staticmethod
+    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
+        node.protocol.record_validator.extend(record_validators)
+
+    @property
+    def peer_id(self) -> PeerID:
+        if self._peer_id is None:
+            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+        return self._peer_id
+
+    @staticmethod
+    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
+        return node.peer_id
+
+    @property
+    def client_mode(self) -> bool:
+        if self._client_mode is None:
+            self._client_mode = self.run_coroutine(DHT._get_client_mode)
+        return self._client_mode
+
+    @staticmethod
+    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
+        return node.protocol.client_mode
+
+    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
+        """
+        Get multiaddrs of the current DHT node that should be accessible by other peers.
+
+        :param latest: ask the P2P daemon to refresh the visible multiaddrs
+        """
+
+        return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
+
+    @staticmethod
+    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
+        return await node.get_visible_maddrs(latest=latest)
+
+    async def replicate_p2p(self) -> P2P:
+        """
+        Get a replica of a P2P instance used in the DHT process internally.
+        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
+        """
+
+        if self._p2p_replica is None:
+            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
+            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
+        return self._p2p_replica
+
+    @staticmethod
+    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
+        return node.p2p.daemon_listen_maddr
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()
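
Since the class body is unchanged by the move, existing usage carries over. A short usage sketch based on the methods defined above (the key and value are illustrative):

    import hivemind

    dht = hivemind.DHT(start=True)                        # spawns a DHT node in a background process
    expiration = hivemind.get_dht_time() + 300            # absolute expiration time, 5 minutes from now
    assert dht.store("example_key", b"example_value", expiration_time=expiration)

    found = dht.get("example_key", latest=True)           # returns (value, expiration_time) or None
    if found is not None:
        value, expiration_time = found

    future = dht.get("example_key", return_future=True)   # non-blocking variant returns an MPFuture
    value_with_expiration = future.result()
    dht.shutdown()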

+ 1 - 1
hivemind/dht/routing.py

@@ -10,7 +10,7 @@ from itertools import chain
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from hivemind.p2p import PeerID
-from hivemind.utils import MSGPackSerializer, get_dht_time
+from hivemind.utils import DHTExpiration, MSGPackSerializer, get_dht_time
 
 DHTKey = Subkey = DHTValue = Any
 BinaryDHTID = BinaryDHTValue = bytes

+ 1 - 0
hivemind/optim/__init__.py

@@ -3,3 +3,4 @@ from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.grad_scaler import HivemindGradScaler
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
+from hivemind.optim.training_averager import TrainingAverager

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence
 
 import torch.optim
 
-from hivemind import TrainingAverager
 from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.training_averager import TrainingAverager
 
 
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 2 - 2
hivemind/optim/collaborative.py

@@ -9,14 +9,14 @@ import numpy as np
 import torch
 from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
 
-from hivemind.averaging.training import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.grad_scaler import HivemindGradScaler
-from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils.performance_ema import PerformanceEMA
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

+ 0 - 0
hivemind/optim/experimental/__init__.py


+ 219 - 0
hivemind/optim/experimental/grad_averager.py

@@ -0,0 +1,219 @@
+import contextlib
+from typing import Iterable, Iterator, Optional
+
+import torch
+
+import hivemind
+from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
+from hivemind.utils import DHTExpiration, get_logger
+
+logger = get_logger(__name__)
+
+
+class GradientAverager(DecentralizedAverager):
+    """
+    An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
+    GradientAverager is meant to be used within hivemind.Optimizer, but it can be used standalone (see example below).
+
+    GradientAverager manages three sets of buffers:
+    (1) model gradients - the gradients associated with local model parameters by PyTorch (param.grad).
+        These tensors are typically stored on device and updated by torch autograd
+    (2) gradient accumulators - an [optional] set of buffers where local gradients are accumulated.
+      - note: if reuse_grad_buffers is True, the averager will use gradients from parameters as local accumulators,
+        which reduces RAM usage but requires the user to avoid calling zero_grad / clip_grad manually
+    (3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
+
+
+    Example:
+
+    >>> model = SuchModelMuchLayers()
+    >>> opt = torch.optim.Adam(model.parameters())
+    >>> grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(...))
+    >>> next_step_time = hivemind.get_dht_time() + 60   # runs global steps every 60 seconds
+    >>> next_step_control = None
+    >>> while True:
+    >>>    # accumulate as many gradients as you can before next_step_time
+    >>>    loss = compute_loss(model, batch_size=32)
+    >>>    loss.backward()
+    >>>    grad_averager.accumulate_grads_(batch_size=32)
+    >>>    # [optional] next step in 5 seconds, start looking for peers in advance
+    >>>    if next_step_time - hivemind.get_dht_time() <= 5:
+    >>>        next_step_control = grad_averager.schedule_step(scheduled_time=next_step_time)
+    >>>    # aggregate gradients and perform optimizer step
+    >>>    if hivemind.get_dht_time() >= next_step_time:
+    >>>        grad_averager.step(control=next_step_control)
+    >>>        with grad_averager.use_averaged_gradients():  # this will fill param.grads with aggregated gradients
+    >>>            opt.step()  # update model parameters using averaged gradients
+    >>>        grad_averager.reset_accumulated_grads_()  # prepare for next step
+    >>>        next_step_time = hivemind.get_dht_time() + 60
+    >>>        next_step_control = None
+
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        *,
+        dht: hivemind.DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = None,
+        warn: bool = True,
+        **kwargs,
+    ):
+        if reuse_grad_buffers and accumulate_grads_on is not None:
+            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        self._parameters = tuple(parameters)
+        self.reuse_grad_buffers = reuse_grad_buffers
+        self.warn = warn
+        self.local_samples_accumulated = 0
+        self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        self._local_accumulators = None
+        if not reuse_grad_buffers:
+            self._local_accumulators = tuple(
+                torch.zeros_like(grad, device=accumulate_grads_on) for grad in self._grads_from_parameters()
+            )
+        self._accumulators_used_in_step = False
+        self._new_averaged_grads = False
+
+        with torch.no_grad():
+            averaged_grads = tuple(
+                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+            )
+        super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
+
+    def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
+        """gradient buffers associated with parameters"""
+        for param in self._parameters:
+            if param.grad is None:
+                param.grad = torch.zeros_like(param)
+            yield param.grad
+
+    @torch.no_grad()
+    def _grad_accumulators(self) -> Iterator[torch.Tensor]:
+        """averager-based gradient accumulators"""
+        assert (self._local_accumulators is None) == self.reuse_grad_buffers
+        yield from self._grads_from_parameters() if self.reuse_grad_buffers else self._local_accumulators
+
+    @torch.no_grad()
+    def accumulate_grads_(self, batch_size: int):
+        """add current gradients to local grad accumulators (if used)"""
+        if self._accumulators_used_in_step and self.warn:
+            logger.warning(
+                "[warn=True] Gradient accumulators were not reset since the last averaging round. Please "
+                "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)."
+            )
+            self._accumulators_used_in_step = False  # warn once per round
+        if self._anchor_batch_size is None:
+            # remember the first batch size to correctly re-scale gradients if subsequent batches have a different size
+            self._anchor_batch_size = batch_size
+        self.local_samples_accumulated += batch_size
+        self.local_times_accumulated += 1
+        if self.reuse_grad_buffers:
+            pass  # user is responsible for accumulating gradients in .grad buffers
+        else:
+            alpha = float(batch_size) / self._anchor_batch_size
+            for grad_buf, grad_acc in zip(self._grads_from_parameters(), self._grad_accumulators()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
+    def step(
+        self,
+        weight: Optional[float] = None,
+        reset_accumulators: bool = True,
+        control: Optional[StepControl] = None,
+        wait: bool = True,
+        **kwargs,
+    ):
+        """
+        Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators
+
+        :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
+        :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
+        :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
+        :param wait: if True, await for the step to finish (or fail), otherwise run all-reduce in background
+        """
+        if control is None:
+            control = self.schedule_step(**kwargs)
+        elif len(kwargs) > 0:
+            raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
+        assert not control.triggered, f"This {type(control)} instance was already used."
+        self._load_accumulators_into_averager_()
+        self._accumulators_used_in_step = True
+        self._new_averaged_grads = True
+
+        control.weight = self.local_samples_accumulated if weight is None else weight
+        if reset_accumulators:
+            self.reset_accumulated_grads_()
+
+        control.allow_allreduce()
+        return control.result() if wait else control
+
+    @torch.no_grad()
+    def _load_accumulators_into_averager_(self):
+        """load locally accumulated gradients into the averager for aggregation"""
+        if self._new_averaged_grads and self.warn:
+            logger.warning(
+                "[warn=True] Starting new averaging round, but previous round results were not used."
+                "This may be a sign of incorrect optimizer behavior."
+            )
+            self._new_averaged_grads = False  # warn once per round
+        # divide locally accumulated gradients by the number of times they were accumulated
+        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
+        with self.get_tensors() as averaged_grads:
+            for grad_acc, averaged_grad in zip(self._grad_accumulators(), averaged_grads):
+                averaged_grad.copy_(grad_acc, non_blocking=True).mul_(grad_scale)
+
+    @torch.no_grad()
+    def reset_accumulated_grads_(self):
+        """reset averager-internal gradient accumulators and the denominator"""
+        self._accumulators_used_in_step = False
+        self.local_samples_accumulated = self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        for grad_buf in self._grad_accumulators():
+            grad_buf.zero_()
+
+    @contextlib.contextmanager
+    @torch.no_grad()
+    def use_averaged_gradients(self):
+        self._new_averaged_grads = False
+        with self.get_tensors() as averaged_grads:
+            try:
+                assert len(averaged_grads) == len(self._parameters)
+                old_grads = [param.grad for param in self._parameters]
+                for param, new_grad in zip(self._parameters, averaged_grads):
+                    param.grad = new_grad
+                yield
+            finally:
+                for param, old_grad in zip(self._parameters, old_grads):
+                    param.grad = old_grad
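
The class docstring above demonstrates the default accumulator mode. For completeness, a minimal sketch of the reuse_grad_buffers=True mode, assuming model, opt, dht and loader are already set up (compute_loss, accumulation_steps and the prefix are illustrative):

    grad_averager = GradientAverager(
        model.parameters(), dht=dht, prefix="my_experiment", reuse_grad_buffers=True
    )
    for step_idx, batch in enumerate(loader):
        loss = compute_loss(model, batch)
        loss.backward()                                          # gradients accumulate directly in param.grad
        grad_averager.accumulate_grads_(batch_size=len(batch))   # in this mode, only counts samples and steps
        if (step_idx + 1) % accumulation_steps == 0:
            # matchmake and all-reduce; reset_accumulators=True (default) also zeroes the .grad buffers
            grad_averager.step()
            with grad_averager.use_averaged_gradients():         # param.grad temporarily holds averaged gradients
                opt.step()
            # note: no manual opt.zero_grad() here - in this mode the averager manages the .grad buffers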

+ 569 - 0
hivemind/optim/experimental/state_averager.py

@@ -0,0 +1,569 @@
+""" An extension of averager that supports common optimization use cases. """
+import logging
+from asyncio import Future
+from concurrent.futures import ThreadPoolExecutor
+from itertools import chain
+from threading import Event
+from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
+
+import torch
+
+import hivemind
+from hivemind import nested_compare
+from hivemind.averaging import DecentralizedAverager
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.utils import get_logger, nested_flatten, nested_map, nested_pack
+
+logger = get_logger(__name__)
+
+
+Parameters = Iterable[torch.Tensor]
+ParamGroups = Iterable[Dict[str, Any]]
+TorchOptimizer = torch.optim.Optimizer
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
+SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
+
+
+class TrainingStateAverager(DecentralizedAverager):
+    """
+    An auxiliary class that holds peer's training state, including model parameters, optimizer statistics, scheduler
+    and any other variables that define the local training state (e.g. batchnorm moving averages).
+    TrainingStateAverager is intended to keep these parameters weakly synchronized across the swarm.
+
+    The intended use is to call .step(optimizer_step=..., averaging_round=...) periodically, e.g. after every batch.
+    If peer gets out of sync with the swarm, one should call state_averager.load_state_from_peers() to re-synchronize.
+
+    Example:
+
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, param_groups=model.parameters(), ...)
+    >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
+    >>> avgr.load_state_from_peers()
+    >>> for i, batch in enumerate(training_dataloader):
+    >>>     loss = compute_loss(model, batch)
+    >>>     loss.backward()
+    >>>     avgr.step(optimizer_step=i % 10 == 0, averaging_round=is_it_time_for_averaging(), delay_averaging=True)
+
+    :note: when using delay_averaging or delay_optimizer_step, calling optimizer directly is not recommended because
+      it may overlap with delayed updates from a background thread with unpredictable results. Instead, please call
+      TrainingStateAverager.step(..., optimizer_step=True)
+
+    :param optimizer: PyTorch Optimizer or a callable that creates an optimizer from param groups
+    :param param_groups: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
+    :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
+    :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
+      state tensors. If False, user must make sure that all tensors are pre-initialized at init.
+      By default, initialize optimizer unless it already has some state tensors to begin with.
+    :param offload_optimizer: if True, create optimizer on top of averaged parameters which may save device memory.
+    :param custom_gradients: if True, do *not* automatically load local gradients into the offloaded optimizer.
+      This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
+    :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
+      For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
+    :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
+    :param parameter_names: optionally provide parameter names in the same order as param_groups
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+    :note: you can use extra_tensors for any tensors not used by the optimizer (e.g. batchnorm statistics)
+    :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: hivemind.DHT,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        initialize_optimizer: Optional[bool] = None,
+        offload_optimizer: bool = False,
+        custom_gradients: bool = False,
+        reuse_tensors: bool = False,
+        sync_epoch_when_averaging: bool = False,
+        parameter_names: Optional[Sequence[str]] = None,
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        status_loglevel: int = logging.DEBUG,
+        **kwargs,
+    ):
+        average_opt_statistics = tuple(average_opt_statistics)
+        assert all(isinstance(key, str) for key in average_opt_statistics)
+        if offload_optimizer and reuse_tensors:
+            logger.warning("Setting offload_optimizer=True has no effect because reuse_parameters=True")
+        if custom_gradients and not offload_optimizer:
+            logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, param_groups, parameter_names)
+
+        self.status_loglevel = status_loglevel
+        self.reuse_tensors = reuse_tensors
+        self.offload_optimizer = offload_optimizer
+        self.custom_gradients = custom_gradients
+
+        self._main_parameters, self._parameter_names = main_parameters, parameter_names
+        self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
+        self.optimizer, self.scheduler = self._init_components(
+            param_groups, optimizer, scheduler, initialize_optimizer
+        )
+        self.opt_keys_for_averaging, self.extra_tensors = average_opt_statistics, extra_tensors
+        self.sync_epoch_when_averaging = sync_epoch_when_averaging
+        self.local_epoch = 0
+
+        self.step_executor = ThreadPoolExecutor(max_workers=1)
+        self.finished_optimizer_step = Event()
+        self.finished_averaging_round = Event()
+        self.pending_update = Future()
+        self.pending_update.set_result(None)
+
+        super().__init__(
+            dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
+        )
+
+    @staticmethod
+    def _check_params(
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        param_groups: Optional[Union[Parameters, ParamGroups]],
+        parameter_names: Optional[Sequence[str]],
+    ) -> Tuple[ParamGroups, Sequence[torch.Tensor], Sequence[str]]:
+        """Get and verify parameters, groups and names"""
+        if param_groups is None:
+            assert hasattr(optimizer, "param_groups"), "Must provide param_groups or an optimizer with .param_groups"
+            param_groups = optimizer.param_groups
+        param_groups = tuple(param_groups)
+        if all(isinstance(p, torch.Tensor) for p in param_groups):
+            param_groups = (dict(params=param_groups),)
+        for group in param_groups:
+            assert isinstance(group, dict) and group.get("params") is not None
+            assert all(isinstance(p, torch.Tensor) for p in group["params"])
+        parameters = tuple(chain(*(group["params"] for group in param_groups)))
+        if parameter_names is None:
+            parameter_names = tuple(i for i in range(len(parameters)))
+        parameter_names = tuple(nested_flatten(parameter_names))
+        assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
+        assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
+        return param_groups, parameters, parameter_names
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
+        """Create a new tensor for averaging or reuse the existing one"""
+        if self.reuse_tensors:
+            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
+            if not source_tensor.is_shared():
+                source_tensor.share_memory_()
+            return source_tensor
+        else:
+            averaged_tensor = source_tensor.detach().to(device="cpu", dtype=torch.float32, copy=True)
+            return averaged_tensor.share_memory_().requires_grad_(source_tensor.requires_grad)
+
+    def _init_components(
+        self,
+        param_groups: ParamGroups,
+        optimizer_or_factory: Union[TorchOptimizer, OptimizerFactory],
+        scheduler_or_factory: Optional[Union[LRSchedulerBase, SchedulerFactory]],
+        initialize_optimizer: Optional[bool],
+    ) -> Tuple[TorchOptimizer, Optional[LRSchedulerBase]]:
+        """Get optimizer and scheduler by either instantiating user-provided factory or using pre-instantiated ones"""
+        assert hasattr(self, "_averaged_parameters"), "Internal error: must initialize averaged parameters first"
+        optimizer_is_factory = callable(optimizer_or_factory) and not isinstance(optimizer_or_factory, TorchOptimizer)
+        scheduler_is_factory = callable(scheduler_or_factory) and not isinstance(scheduler_or_factory, LRSchedulerBase)
+        if optimizer_is_factory and not scheduler_is_factory and scheduler_or_factory is not None:
+            raise ValueError("If optimizer is created internally, scheduler must also be initialized internally")
+        if self.offload_optimizer and not optimizer_is_factory:
+            raise ValueError("Using offload_optimizer requires creating optimizer inside hivemind")
+
+        # create optimizer
+        if optimizer_is_factory:
+            if self.offload_optimizer:
+                for param in self._averaged_parameters:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
+
+                next_index = 0
+                param_groups_for_optimizer = []
+                for param_group in param_groups:
+                    num_params = len(param_group["params"])
+                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
+                    param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
+                    next_index += num_params
+                assert next_index == len(self._averaged_parameters)
+
+            else:
+                param_groups_for_optimizer = param_groups
+            optimizer = optimizer_or_factory(param_groups_for_optimizer)
+        else:
+            optimizer = optimizer_or_factory
+
+        # optionally initialize optimizer state dict
+        if initialize_optimizer is None:
+            initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
+            logger.log(
+                self.status_loglevel,
+                "Initializing optimizer manually since it has no tensors in state dict"
+                "To override this, please provide initialize_optimizer=False",
+            )
+
+        if initialize_optimizer:
+            initialize_optimizer_state_(optimizer)  # note: this will run one optimizer step!
+
+        # create LR scheduler
+        if scheduler_is_factory:
+            assert callable(scheduler_or_factory)
+            scheduler = scheduler_or_factory(optimizer)
+        else:
+            scheduler = scheduler_or_factory
+
+        # verify optimizer and scheduler
+        assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
+        if self.offload_optimizer or self.reuse_tensors:
+            for param_group in optimizer.param_groups:
+                for param in param_group["params"]:
+                    assert param.is_shared()
+        assert isinstance(scheduler, (LRSchedulerBase, type(None)))
+        if scheduler is not None:
+            assert scheduler.optimizer == optimizer
+        return optimizer, scheduler
+
+    def _local_tensors(self) -> Iterator[torch.Tensor]:
+        """Iterate local trainer's tensors that should be averaged with peers"""
+        for param_group in self.optimizer.param_groups:
+            yield from param_group["params"]
+        for stats in self.opt_keys_for_averaging:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group["params"]:
+                    yield self.optimizer.state[param][stats]
+        yield from self.extra_tensors
+
+    @torch.no_grad()
+    def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
+        """Create or reuse a tuple of all averaged tensors, including parameters, optimizer statistics and extras"""
+        assert hasattr(self, "optimizer"), "Optimizer should already be initialized by this point"
+        assert hasattr(self, "_averaged_parameters"), "Should initialize _averaged_parameters first"
+        assert not hasattr(self, "_averaged_tensors"), "Averager is already initialized"
+        assert all(isinstance(key, str) for key in self.opt_keys_for_averaging)
+
+        local_tensors = tuple(self._local_tensors())
+        local_non_parameters = local_tensors[len(self._averaged_parameters) :]
+        averaged_tensors = tuple(map(torch.Tensor.detach, self._averaged_parameters))
+        averaged_non_parameters = tuple(map(self._make_host_tensor, local_non_parameters))
+        averaged_tensors = tuple(chain(averaged_tensors, averaged_non_parameters))
+
+        assert len(averaged_tensors) == len(local_tensors)
+        for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+            assert local_tensor.shape == averaged_tensor.shape
+            if averaged_tensor.grad is not None:
+                logger.log(self.status_loglevel, "Setting gradients for averaged tensor to None")
+
+        return averaged_tensors
+
+    def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
+        """Get CompressionInfo for each state tensor, accounting for its role and specification"""
+        tensor_infos = []
+        for param, param_name in zip(self._main_parameters, self._parameter_names):
+            tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
+        for stats_name in self.opt_keys_for_averaging:
+            opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+            assert len(opt_parameters) == len(self._parameter_names)
+            for param, param_name in zip(opt_parameters, self._parameter_names):
+                tensor_infos.append(
+                    CompressionInfo.from_tensor(
+                        self.optimizer.state[param][stats_name],
+                        key=(param_name, stats_name),
+                        role=TensorRole.OPTIMIZER,
+                    )
+                )
+        for i, extra_tensor in enumerate(self.extra_tensors):
+            tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
+        return tuple(tensor_infos)
+
+    def step(
+        self,
+        wait_for_delayed_update: Optional[bool] = None,
+        apply_delayed_updates: bool = True,
+        increment_epoch: bool = False,
+        optimizer_step: bool = False,
+        zero_grad: bool = False,
+        delay_optimizer_step: bool = False,
+        averaging_round: bool = False,
+        delay_averaging: Optional[bool] = None,
+        averaging_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Perform one or several possible actions, depending on the specified keyword args.
+        The actions will be performed in the same order as specified below:
+
+        :param wait_for_delayed_update: if there are background updates, wait for them to finish before proceeding;
+          defaults to True whenever an optimizer step, zero_grad or averaging round is scheduled, otherwise False
+        :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
+        :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
+        :param zero_grad: if True, reset local gradients after performing optimizer step
+        :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
+        :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
+        :param delay_averaging: if True, perform averaging in background and apply results in a future step;
+          by default, averaging is delayed if the optimizer step is also delayed. Set to True to delay only this phase.
+        :param averaging_kwargs: a dict of keyword arguments forwarded into averaging round
+        """
+        if delay_averaging is None:
+            delay_averaging = delay_optimizer_step
+        if wait_for_delayed_update is None:
+            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
+        assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
+        if optimizer_step or averaging_round or zero_grad:
+            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
+        if delay_optimizer_step:
+            assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
+            assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
+        if averaging_kwargs and not averaging_round:
+            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_kwargs}")
+        output = None
+
+        if wait_for_delayed_update:
+            if not self.pending_update.done():
+                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                output = self.pending_update.result()
+
+        if self.pending_update.done() and self.pending_update.exception():
+            logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")
+
+        if apply_delayed_updates:
+            if self.finished_averaging_round.is_set():
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                logger.log(self.status_loglevel, "Received results from background averaging round")
+                self.finished_averaging_round.clear()
+
+            if self.finished_optimizer_step.is_set():
+                if self.offload_optimizer:
+                    self._apply_optimizer_results_()
+                logger.log(self.status_loglevel, "Received results from background optimizer step")
+                self.finished_optimizer_step.clear()
+
+        if increment_epoch:
+            self.local_epoch += 1
+            logger.log(self.status_loglevel, f"Switching to epoch {self.local_epoch}")
+            self._update_scheduler()
+
+        if optimizer_step or zero_grad or averaging_round:
+            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
+
+            if self.offload_optimizer and not self.custom_gradients:
+                self._load_local_grads_into_optimizer_()
+
+            self.pending_update = self.step_executor.submit(
+                self._do,
+                optimizer_step,
+                zero_grad,
+                averaging_round,
+                **averaging_kwargs or {},
+            )
+
+            if (optimizer_step or zero_grad) and not delay_optimizer_step:
+                self.finished_optimizer_step.wait()
+                self.finished_optimizer_step.clear()
+                if self.offload_optimizer:
+                    self._apply_optimizer_results_()
+                logger.log(self.status_loglevel, "Finished optimizer step")
+
+            if averaging_round and not delay_averaging:
+                self.finished_averaging_round.wait()
+                self.finished_averaging_round.clear()
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                logger.log(self.status_loglevel, "Finished averaging round")
+
+            if not delay_averaging:
+                try:
+                    output = self.pending_update.result()
+                finally:
+                    self.finished_averaging_round.clear()
+                    self.finished_optimizer_step.clear()
+        return output
+
+    def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kwargs):
+        """
+        Run the optimizer step, zero gradients, and/or an averaging round; each stage is optional.
+        This method is meant to be called in the background executor.
+        """
+        try:
+            if optimizer_step:
+                logger.log(self.status_loglevel, f"Running optimizer step")
+                self.optimizer.step()
+            if zero_grad:
+                logger.log(self.status_loglevel, f"Running zero grad")
+                self.optimizer.zero_grad()
+                if self.offload_optimizer:
+                    for parameter in self._main_parameters:
+                        if parameter.grad is not None:
+                            parameter.grad.zero_()
+
+            self.finished_optimizer_step.set()
+
+            if averaging_round:
+                if not self.reuse_tensors:
+                    self._load_local_tensors_into_averager_()
+                try:
+                    gathered = super().step(gather=self.local_epoch, **kwargs)
+                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                except BaseException as e:
+                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                    self.finished_averaging_round.set()
+                    gathered = {}
+
+                self.finished_averaging_round.set()
+
+                if self.sync_epoch_when_averaging:
+                    old_epoch = self.local_epoch
+                    for peer_epoch in gathered.values():
+                        self.local_epoch = max(self.local_epoch, peer_epoch)
+                    if self.local_epoch != old_epoch:
+                        logger.log(self.status_loglevel, f"Found peer with newer epoch ({self.local_epoch})")
+                        self._update_scheduler()
+
+        except Exception as e:
+            logger.exception(e)
+            self.finished_optimizer_step.set()
+            self.finished_averaging_round.set()
+
+    @torch.no_grad()
+    def _load_local_grads_into_optimizer_(self):
+        """Copy local gradients into the gradient buffers of the offloaded optimizer"""
+        assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
+        opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        for main_param, opt_param in zip(self._main_parameters, opt_parameters):
+            if main_param.grad is not None:
+                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_optimizer_results_(self):
+        """Copy parameters from offloaded optimizer to the main model"""
+        assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
+        with self.lock_averaged_tensors:
+            offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+            assert len(offloaded_parameters) == len(self._main_parameters), "opt parameters changed during training"
+            for main_param, offloaded_param in zip(self._main_parameters, offloaded_parameters):
+                main_param.copy_(offloaded_param, non_blocking=True)
+
+    @torch.no_grad()
+    def _load_local_tensors_into_averager_(self):
+        """Copy local tensors into the averaging buffers"""
+        assert not self.reuse_tensors, "No need to load tensors into averager: both tensors share the same memory"
+        with self.get_tensors() as averaged_tensors:
+            for local_tensor, averaged_tensor in zip(self._local_tensors(), averaged_tensors):
+                averaged_tensor.copy_(local_tensor, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_averaging_results_(self):
+        """Copy averaged tensors into their respective local tensors"""
+        assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        with self.get_tensors() as averaged_tensors:
+            local_tensors = list(self._local_tensors())
+            assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
+            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                local_tensor.copy_(averaged_tensor, non_blocking=True)
+
+    def get_current_state(self):
+        """
+        Get the current model/optimizer state to be sent when requested by a newly joined peer; executed in the host process.
+        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        """
+        with torch.no_grad():
+            optimized_parameters = tuple(
+                param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
+            )
+            parameter_infos = [
+                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+                for param, key in zip(optimized_parameters, self._parameter_names)
+            ]
+            extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            extra_infos = [
+                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+                for i, extra_tensor in enumerate(extra_tensors)
+            ]
+            optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.optimizer)
+            optimizer_infos = [
+                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
+                for i, opt_tensor in enumerate(optimizer_tensors)
+            ]
+
+        metadata = dict(
+            epoch=self.local_epoch, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata
+        )
+        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
+        return metadata, all_tensors, all_tensor_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
+        parameters_and_extras = tuple(chain(self._main_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(parameters_and_extras)
+
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        if (not isinstance(metadata.get("epoch"), int)) or metadata["epoch"] < self.local_epoch:
+            logger.warning("Cowardly refusing to load state from peer: peer's epoch is behind our local epoch")
+            return
+
+        loaded_parameters_and_extras = flat_tensors[:num_parameters_and_extras]
+        loaded_opt_tensors = flat_tensors[num_parameters_and_extras:]
+        if num_parameters_and_extras != len(loaded_parameters_and_extras):
+            logger.error("Failed to load state from peer, received parameters, extras or metadata.")
+            return
+
+        try:
+            load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+        except StopIteration:
+            logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+            return
+
+        with torch.no_grad():
+            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
+                local_param.copy_(loaded_param, non_blocking=True)
+        self.local_epoch = metadata["epoch"]
+        self._update_scheduler()
+
+    def _update_scheduler(self):
+        """Increase the scheduler state until it becomes synchronized with local epoch"""
+        if self.scheduler:
+            while self.scheduler._step_count <= self.local_epoch:
+                self.scheduler.step()
+
+
+def initialize_optimizer_state_(opt: torch.optim.Optimizer):
+    """Initialize optimizer statistics by running a virtual optimizer step with zero gradients"""
+    flat_params = tuple(param for group in opt.param_groups for param in group["params"])
+    old_grads = []
+    for param in flat_params:
+        old_grads.append(param.grad)
+        param.grad = torch.zeros_like(param)
+    opt.step()
+    for param, old_grad in zip(flat_params, old_grads):
+        param.grad = old_grad
+
+
+def dump_optimizer_state(opt: torch.optim.Optimizer):
+    """Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers"""
+    with torch.no_grad():
+        flat_metadata, flat_tensors = [], []
+        for elem in nested_flatten(opt.state_dict()):
+            if isinstance(elem, torch.Tensor):
+                flat_metadata.append(dict(type="tensor", index=len(flat_tensors)))
+                flat_tensors.append(elem.cpu())
+            else:
+                flat_metadata.append(dict(type="value", value=elem))
+        return flat_metadata, flat_tensors
+
+
+def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict, flat_tensors: Sequence[torch.Tensor]):
+    """Load a state obtained by dump_optimizer_state back into the optimizer"""
+    flat_optimizer_state = []
+    for elem in flat_metadata:
+        if elem.get("type") == "tensor" and isinstance(elem.get("index"), int):
+            flat_optimizer_state.append(flat_tensors[elem["index"]])
+        elif elem.get("type") == "value" and "value" in elem:
+            flat_optimizer_state.append(elem["value"])
+    return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
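
For reference, dump_optimizer_state and load_optimizer_state above are inverses of each other: the first flattens an optimizer's state_dict into serializable metadata plus a flat list of CPU tensors, the second packs the same pieces back into the optimizer. A minimal round-trip sketch, assuming the helpers live next to TrainingStateAverager in hivemind.optim.experimental.state_averager (the module the tests below import from); the toy SGD optimizer is purely illustrative:

import torch

# assumption: the helpers ship in the same module as TrainingStateAverager
from hivemind.optim.experimental.state_averager import dump_optimizer_state, load_optimizer_state

param = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
param.grad = torch.ones(3)
opt.step()  # populate the momentum buffer so that state_dict() actually contains tensors

metadata, tensors = dump_optimizer_state(opt)  # flat metadata + CPU tensors, ready for streaming
load_optimizer_state(opt, metadata, tensors)   # restore the exact same state into the optimizer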

+ 1 - 1
hivemind/optim/simple.py

@@ -4,9 +4,9 @@ from typing import Optional, Sequence, Tuple
 
 import torch
 
-from hivemind.averaging import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)

+ 0 - 0
hivemind/averaging/training.py → hivemind/optim/training_averager.py


+ 23 - 1
hivemind/utils/asyncio.py

@@ -1,7 +1,8 @@
 import asyncio
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
+from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
@@ -147,3 +148,24 @@ async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Ev
             yield item
     finally:
         event.set()
+
+
+class _AsyncContextWrapper(AbstractAsyncContextManager):
+    """Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
+
+    def __init__(self, context: AbstractContextManager):
+        self._context = context
+
+    async def __aenter__(self):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self._context.__enter__)
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        return self._context.__exit__(exc_type, exc_value, traceback)
+
+
+@asynccontextmanager
+async def enter_asynchronously(context: AbstractContextManager):
+    """Wrap a non-async context so that it can be entered asynchronously"""
+    async with _AsyncContextWrapper(context) as ret_value:
+        yield ret_value
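
enter_asynchronously is what allows a coroutine to take an ordinary blocking lock without stalling the event loop: __enter__ runs in the default thread-pool executor while other coroutines keep running. A minimal sketch; the lock and coroutine names are illustrative, not part of this change:

import asyncio
import threading

from hivemind.utils.asyncio import enter_asynchronously

heavy_lock = threading.Lock()  # an ordinary blocking lock, possibly shared with non-async code

async def guarded_work(delay: float):
    async with enter_asynchronously(heavy_lock):  # acquired in the executor, not on the event loop
        await asyncio.sleep(delay)

async def main():
    # both coroutines make progress; the second one waits for the lock in a background thread
    await asyncio.gather(guarded_work(0.2), guarded_work(0.1))

asyncio.run(main())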

+ 5 - 1
hivemind/optim/performance_ema.py → hivemind/utils/performance_ema.py

@@ -37,6 +37,10 @@ class PerformanceEMA:
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
         return self.samples_per_second
 
+    def reset_timer(self):
+        """Reset the time since the last update so that the next task performance is counted from current time"""
+        self.timestamp = time.perf_counter()
+
     @contextmanager
     def pause(self):
         """While inside this context, EMA will not count the time passed towards the performance estimate"""
@@ -44,8 +48,8 @@ class PerformanceEMA:
         try:
             yield
         finally:
-            self.timestamp = time.perf_counter()
             self.paused = was_paused
+            self.reset_timer()
 
     def __repr__(self):
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"

+ 5 - 2
tests/conftest.py

@@ -1,13 +1,13 @@
 import asyncio
 import gc
-import multiprocessing as mp
 from contextlib import suppress
 
 import psutil
 import pytest
 
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.mpfuture import MPFuture, SharedBytes
+from hivemind.utils.mpfuture import MPFuture
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -33,6 +33,9 @@ def event_loop():
 def cleanup_children():
     yield
 
+    with RSAPrivateKey._process_wide_key_lock:
+        RSAPrivateKey._process_wide_key = None
+
     gc.collect()  # Call .__del__() for removed objects
 
     children = psutil.Process().children(recursive=True)

+ 2 - 2
tests/test_averaging.py

@@ -481,7 +481,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.averaging.TrainingAverager(
+    averager1 = hivemind.TrainingAverager(
         opt1,
         average_gradients=True,
         average_parameters=True,
@@ -492,7 +492,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.averaging.TrainingAverager(
+    averager2 = hivemind.TrainingAverager(
         opt2,
         average_gradients=True,
         average_parameters=True,

+ 172 - 0
tests/test_optimizer.py

@@ -0,0 +1,172 @@
+import time
+from functools import partial
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import hivemind
+from hivemind.averaging.control import AveragingStage
+from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.state_averager import TrainingStateAverager
+
+
+@pytest.mark.forked
+def test_grad_averager():
+    dht1 = hivemind.DHT(start=True)
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager1 = GradientAverager(
+        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
+    )
+
+    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager2 = GradientAverager(
+        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
+    )
+
+    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
+    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)
+
+    for i in range(10):
+        time.sleep(0.1)
+        if i % 3 == 0:
+            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1.backward()
+            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
+            model1.zero_grad()
+        else:
+            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2.backward()
+            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
+            # note: we do not call zero grad here because reuse_grad_buffers=True
+
+    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
+    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
+    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
+    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
+
+    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
+    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
+
+    averager1.step(control=control1, wait=False)
+    averager2.step(control=control2, wait=False)
+    for step in (control1, control2):
+        step.result()  # wait for all-reduce to finish
+
+    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
+    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
+    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
+    with averager1.use_averaged_gradients():
+        assert torch.allclose(model1.w.grad, ref_average)
+    with averager2.use_averaged_gradients():
+        assert torch.allclose(model2.w.grad, ref_average)
+
+    # outside the use_averaged_gradients context, the original (non-averaged) local gradients are restored
+    assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
+    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
+)
+def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    torch.manual_seed(1337)
+    torch.use_deterministic_algorithms(True)
+    # note: use_deterministic_algorithms does not affect further tests because this test is forked
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    extras1 = (torch.randn(2, 2), -torch.rand(1))
+    extras2 = (-torch.randn(2, 2), torch.rand(1))
+
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        sync_epoch_when_averaging=sync_epoch_when_averaging,
+        average_opt_statistics=("exp_avg_sq",),
+        offload_optimizer=offload_optimizer,
+        reuse_tensors=reuse_tensors,
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, param_groups=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
+    )
+    avgr2 = TrainingStateAverager(
+        dht=dht2, param_groups=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
+    )
+
+    x = torch.ones(2)
+
+    for step in range(20):
+        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
+        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)
+
+        F.mse_loss(model2(x), -torch.ones(3)).backward()
+        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
+
+    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
+    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
+    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"
+
+    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    assert not torch.allclose(stats1, stats2)
+
+    avgr1.step(increment_epoch=True)
+
+    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+
+    avgr1.step(wait_for_delayed_update=True)
+    avgr2.step(wait_for_delayed_update=True)
+
+    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
+    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert avgr1.local_epoch == 2
+    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)
+
+
+@pytest.mark.forked
+def test_load_state_from_peers():
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.SGD, lr=0.1),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, param_groups=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+    )
+
+    avgr2 = TrainingStateAverager(dht=dht2, param_groups=model2.parameters(), start=True, **common_kwargs)
+
+    avgr2.local_epoch = 1337
+    model2.weight.data[...] = 42
+    time.sleep(0.1)
+
+    avgr1.load_state_from_peers()
+    assert avgr1.local_epoch == 1337
+    assert torch.all(model1.weight == 42).item()
+    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)

+ 19 - 1
tests/test_util_modules.py

@@ -11,7 +11,6 @@ import torch
 
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
@@ -28,8 +27,10 @@ from hivemind.utils.asyncio import (
     attach_event_on_finished,
     azip,
     cancel_and_wait,
+    enter_asynchronously,
 )
 from hivemind.utils.mpfuture import InvalidStateError
+from hivemind.utils.performance_ema import PerformanceEMA
 
 
 @pytest.mark.forked
@@ -538,6 +539,23 @@ async def test_cancel_and_wait():
     assert not await cancel_and_wait(task_with_error)
 
 
+@pytest.mark.asyncio
+async def test_async_context():
+    lock = mp.Lock()
+
+    async def coro1():
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.2)
+
+    async def coro2():
+        await asyncio.sleep(0.1)
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.1)
+
+    await asyncio.wait_for(asyncio.gather(coro1(), coro2()), timeout=0.5)
+    # running this without enter_asynchronously would deadlock the event loop
+
+
 def test_batch_tensor_descriptor_msgpack():
     tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
     tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))