Merge branch 'master' into rfc_optimizer

justheuristic · 3 years ago · commit abbea26c9c

BIN
docs/_static/dht.odp


BIN
docs/_static/dht.png


+ 2 - 1
hivemind/__init__.py

@@ -1,4 +1,4 @@
-from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
@@ -16,6 +16,7 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
+    TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *

+ 0 - 1
hivemind/averaging/__init__.py

@@ -1,2 +1 @@
 from hivemind.averaging.averager import DecentralizedAverager
-from hivemind.averaging.training import TrainingAverager
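
Editorial note (not part of the commit): this merge stops re-exporting TrainingAverager from hivemind.averaging and exposes it through hivemind.optim instead (see hivemind/optim/__init__.py below). A minimal sketch of the import paths that remain valid after the merge, assuming no compatibility alias is kept under hivemind.averaging:

    # Still works: the top-level package re-exports it via hivemind.optim
    from hivemind import TrainingAverager

    # New canonical locations introduced by this merge
    from hivemind.optim import TrainingAverager
    from hivemind.optim.training_averager import TrainingAverager

    # No longer works after this merge:
    # from hivemind.averaging import TrainingAverager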

+ 10 - 11
hivemind/averaging/averager.py

@@ -32,7 +32,15 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
+from hivemind.utils.asyncio import (
+    achain,
+    aiter_with_timeout,
+    anext,
+    as_aiter,
+    azip,
+    enter_asynchronously,
+    switch_to_uvloop,
+)
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -453,7 +461,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )

-            async with self.get_tensors_async() as local_tensors:
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -505,15 +513,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with self.lock_averaged_tensors:
             yield self._averaged_tensors

-    @contextlib.asynccontextmanager
-    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
-        """Like get_tensors, but uses an asynchronous contextmanager"""
-        try:
-            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
-            yield self._averaged_tensors
-        finally:
-            self.lock_averaged_tensors.release()
-
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
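
Editorial note: the bespoke get_tensors_async context manager is dropped in favor of the shared enter_asynchronously helper from hivemind.utils.asyncio, which enters the ordinary get_tensors() lock without blocking the event loop. The helper itself is defined elsewhere in the codebase; the sketch below is an independent illustration of the pattern with hypothetical names, not hivemind's implementation:

    import asyncio
    import contextlib
    import threading

    @contextlib.asynccontextmanager
    async def enter_async(sync_cm):
        """Enter a blocking context manager in the default executor so the event loop stays responsive."""
        loop = asyncio.get_running_loop()
        value = await loop.run_in_executor(None, sync_cm.__enter__)
        try:
            yield value
        finally:
            await loop.run_in_executor(None, lambda: sync_cm.__exit__(None, None, None))

    async def demo():
        lock = threading.Lock()

        @contextlib.contextmanager
        def get_tensors_stub():  # stand-in for DecentralizedAverager.get_tensors
            with lock:
                yield ["averaged tensors would live here"]

        async with enter_async(get_tensors_stub()) as tensors:
            print(tensors)

    asyncio.run(demo())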

+ 1 - 1
hivemind/averaging/control.py

@@ -103,7 +103,7 @@ class StepControl(MPFuture):
     @stage.setter
     def stage(self, stage: AveragingStage):
         if stage == AveragingStage.RUNNING_ALLREDUCE:
-            self.can_modify = False
+            self.began_allreduce = True
         self._shared_buffer[StepControl._STAGE] = stage.value

     @property
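
Editorial note: instead of clearing can_modify, the stage setter now records that all-reduce has started. A hedged sketch of how calling code might consult this flag before touching step parameters (the guard function below is illustrative, not part of StepControl):

    def try_set_weight(control, new_weight) -> bool:
        # Once all-reduce has begun, changing the averaging weight can no longer
        # take effect for this round, so the update is skipped.
        if control.began_allreduce:
            return False
        control.weight = new_weight
        return True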

+ 3 - 322
hivemind/dht/__init__.py

@@ -4,7 +4,7 @@ Hivemind DHT is based on Kademlia [1] with added support for improved bulk store

 The code is organized as follows:

- * **class DHT (__init__.py)** - high-level class for model training. Runs DHTNode in a background process.
+ * **class DHT (dht.py)** - high-level class for model training. Runs DHTNode in a background process.
  * **class DHTNode (node.py)** - an asyncio implementation of dht server, stores AND gets keys.
  * **class DHTProtocol (protocol.py)** - an RPC protocol to request data from dht nodes.
  * **async def traverse_dht (traverse.py)** - a search algorithm that crawls DHT peers.
@@ -12,327 +12,8 @@ The code is organized as follows:
 - [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric.
 - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
 """
-from __future__ import annotations
-
-import asyncio
-import multiprocessing as mp
-import os
-from functools import partial
-from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
-
-from multiaddr import Multiaddr

+from hivemind.dht.dht import DHT
 from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
-from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
-from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_logger, switch_to_uvloop
-
-logger = get_logger(__name__)
-
-ReturnType = TypeVar("ReturnType")
-
-
-class DHT(mp.Process):
-    """
-    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
-    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
-    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
-
-    :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
-    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
-    :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
-      (but no more than one per key)
-    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
-    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
-      The validators will be combined using the CompositeValidator class. It merges them when possible
-      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
-    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
-    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
-    :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
-    """
-
-    _node: DHTNode
-
-    def __init__(
-        self,
-        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-        *,
-        start: bool,
-        p2p: Optional[P2P] = None,
-        daemon: bool = True,
-        num_workers: int = DEFAULT_NUM_WORKERS,
-        record_validators: Iterable[RecordValidatorBase] = (),
-        shutdown_timeout: float = 3,
-        await_ready: bool = True,
-        **kwargs,
-    ):
-        self._parent_pid = os.getpid()
-        super().__init__()
-
-        if not (
-            initial_peers is None
-            or (
-                isinstance(initial_peers, Sequence)
-                and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
-            )
-        ):
-            raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
-        self.initial_peers = initial_peers
-        self.kwargs = kwargs
-        self.num_workers = num_workers
-
-        self._record_validator = CompositeValidator(record_validators)
-        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
-        self.shutdown_timeout = shutdown_timeout
-        self._ready = MPFuture()
-        self.daemon = daemon
-
-        # These values will be fetched from the child process when requested
-        self._peer_id = None
-        self._client_mode = None
-        self._p2p_replica = None
-
-        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
-
-        if start:
-            self.run_in_background(await_ready=await_ready)
-
-    def run(self) -> None:
-        """Serve DHT forever. This function will not return until DHT node is shut down"""
-
-        loop = switch_to_uvloop()
-        pipe_semaphore = asyncio.Semaphore(value=0)
-        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
-
-        async def _run():
-            try:
-                if self._daemon_listen_maddr is not None:
-                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
-                else:
-                    replicated_p2p = None
-
-                self._node = await DHTNode.create(
-                    initial_peers=self.initial_peers,
-                    num_workers=self.num_workers,
-                    record_validator=self._record_validator,
-                    p2p=replicated_p2p,
-                    **self.kwargs,
-                )
-            except Exception as e:
-                # Loglevel is DEBUG since normally the exception is propagated to the caller
-                logger.debug(e, exc_info=True)
-                self._ready.set_exception(e)
-                return
-            self._ready.set_result(None)
-
-            while True:
-                try:
-                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
-                except asyncio.TimeoutError:
-                    pass
-                if not self._inner_pipe.poll():
-                    continue
-                try:
-                    method, args, kwargs = self._inner_pipe.recv()
-                except (OSError, ConnectionError, RuntimeError) as e:
-                    logger.exception(e)
-                    await asyncio.sleep(self._node.protocol.wait_timeout)
-                    continue
-                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                if method == "_shutdown":
-                    await task
-                    break
-
-        loop.run_until_complete(_run())
-
-    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
-        """
-        Starts DHT in a background process. if await_ready, this method will wait until background dht
-        is ready to process incoming requests or for :timeout: seconds max.
-        """
-        self.start()
-        if await_ready:
-            self.wait_until_ready(timeout)
-
-    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
-        self._ready.result(timeout=timeout)
-
-    def shutdown(self) -> None:
-        """Shut down a running dht process"""
-        if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [], {}))
-            self.join(self.shutdown_timeout)
-            if self.is_alive():
-                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
-                self.terminate()
-
-    async def _shutdown(self):
-        await self._node.shutdown()
-
-    def get(
-        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
-    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
-        """
-        Search for a key across DHT and return either first or latest entry (if found).
-        :param key: same key as in node.store(...)
-        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
-        :returns: (value, expiration time); if value was not found, returns None
-        """
-        future = MPFuture()
-        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
-        return future if return_future else future.result()
-
-    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
-        try:
-            result = await self._node.get(key, latest=latest, **kwargs)
-            if not future.done():
-                future.set_result(result)
-        except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
-            raise
-
-    def store(
-        self,
-        key: DHTKey,
-        value: DHTValue,
-        expiration_time: DHTExpiration,
-        subkey: Optional[Subkey] = None,
-        return_future: bool = False,
-        **kwargs,
-    ) -> Union[bool, MPFuture]:
-        """
-        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
-
-        :param key: msgpack-serializable key to be associated with value until expiration.
-        :param value: msgpack-serializable value to be stored under a given key until expiration.
-        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
-        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: True if store succeeds, False if it fails (due to no response or newer value)
-        """
-        future = MPFuture()
-        self._outer_pipe.send(
-            (
-                "_store",
-                [],
-                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
-            )
-        )
-        return future if return_future else future.result()
-
-    async def _store(
-        self,
-        key: DHTKey,
-        value: DHTValue,
-        expiration_time: DHTExpiration,
-        subkey: Optional[Subkey],
-        future: MPFuture,
-        **kwargs,
-    ):
-        try:
-            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
-            if not future.done():
-                future.set_result(result)
-        except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
-            raise
-
-    def run_coroutine(
-        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
-    ) -> Union[ReturnType, MPFuture[ReturnType]]:
-        """
-        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
-         for running custom functions on the DHT for special cases (e.g. declare experts, beam search)
-
-        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
-        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-        :returns: coroutine outputs or MPFuture for these outputs
-        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
-          DHT fields made by this coroutine will not be accessible from the host process.
-        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
-          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
-        :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
-        """
-        future = MPFuture()
-        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
-        return future if return_future else future.result()
-
-    async def _run_coroutine(
-        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
-    ):
-        try:
-            future.set_result(await coro(self, self._node))
-        except BaseException as e:
-            logger.exception("Caught an exception when running a coroutine:")
-            future.set_exception(e)
-
-    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
-        if not self._ready.done():
-            raise RuntimeError(
-                "Can't append new validators before the DHT process has started. "
-                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
-            )
-
-        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
-
-    @staticmethod
-    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
-        node.protocol.record_validator.extend(record_validators)
-
-    @property
-    def peer_id(self) -> PeerID:
-        if self._peer_id is None:
-            self._peer_id = self.run_coroutine(DHT._get_peer_id)
-        return self._peer_id
-
-    @staticmethod
-    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
-        return node.peer_id
-
-    @property
-    def client_mode(self) -> bool:
-        if self._client_mode is None:
-            self._client_mode = self.run_coroutine(DHT._get_client_mode)
-        return self._client_mode
-
-    @staticmethod
-    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
-        return node.protocol.client_mode
-
-    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
-        """
-        Get multiaddrs of the current DHT node that should be accessible by other peers.
-
-        :param latest: ask the P2P daemon to refresh the visible multiaddrs
-        """
-
-        return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
-
-    @staticmethod
-    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
-        return await node.get_visible_maddrs(latest=latest)
-
-    async def replicate_p2p(self) -> P2P:
-        """
-        Get a replica of a P2P instance used in the DHT process internally.
-        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
-        """
-
-        if self._p2p_replica is None:
-            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
-            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
-        return self._p2p_replica
-
-    @staticmethod
-    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
-        return node.p2p.daemon_listen_maddr
-
-    def __del__(self):
-        if self._parent_pid == os.getpid() and self.is_alive():
-            self.shutdown()
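
Editorial note: the DHT class body above is removed from hivemind/dht/__init__.py and reappears unchanged in the new hivemind/dht/dht.py below; the package __init__ now keeps only the docstring and re-exports. A quick sketch of the equivalence, assuming the re-export shown in this diff:

    # Both paths resolve to the same class after this commit:
    from hivemind.dht import DHT                   # re-exported by hivemind/dht/__init__.py
    from hivemind.dht.dht import DHT as DHTImpl    # new module added in this commit

    assert DHT is DHTImpl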

+ 324 - 0
hivemind/dht/dht.py

@@ -0,0 +1,324 @@
+from __future__ import annotations
+
+import asyncio
+import multiprocessing as mp
+import os
+from functools import partial
+from typing import Awaitable, Callable, Iterable, List, Optional, Sequence, TypeVar, Union
+
+from multiaddr import Multiaddr
+
+from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
+from hivemind.dht.routing import DHTKey, DHTValue, Subkey
+from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P, PeerID
+from hivemind.utils import MPFuture, get_logger, switch_to_uvloop
+from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration
+
+logger = get_logger(__name__)
+ReturnType = TypeVar("ReturnType")
+
+
+class DHT(mp.Process):
+    """
+    A high-level interface to a hivemind DHT that runs a single DHT node in a background process.
+    * hivemind servers periodically announce their experts via declare_experts (dht_handler.py)
+    * trainers find most suitable experts via RemoteMixtureOfExperts (beam_search.py)
+
+    :param initial_peers: multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
+    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
+    :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
+    :param num_workers: declare_experts and get_experts will use up to this many parallel workers
+      (but no more than one per key)
+    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
+      The validators will be combined using the CompositeValidator class. It merges them when possible
+      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
+    :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
+    :param await_ready: if True, the constructor waits until the DHT process is ready to process incoming requests
+    :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
+    """
+
+    _node: DHTNode
+
+    def __init__(
+        self,
+        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+        *,
+        start: bool,
+        p2p: Optional[P2P] = None,
+        daemon: bool = True,
+        num_workers: int = DEFAULT_NUM_WORKERS,
+        record_validators: Iterable[RecordValidatorBase] = (),
+        shutdown_timeout: float = 3,
+        await_ready: bool = True,
+        **kwargs,
+    ):
+        self._parent_pid = os.getpid()
+        super().__init__()
+
+        if not (
+            initial_peers is None
+            or (
+                isinstance(initial_peers, Sequence)
+                and all(isinstance(item, (Multiaddr, str)) for item in initial_peers)
+            )
+        ):
+            raise TypeError("initial_peers should be of type Optional[Sequence[Union[Multiaddr, str]]]")
+        self.initial_peers = initial_peers
+        self.kwargs = kwargs
+        self.num_workers = num_workers
+
+        self._record_validator = CompositeValidator(record_validators)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self.shutdown_timeout = shutdown_timeout
+        self._ready = MPFuture()
+        self.daemon = daemon
+
+        # These values will be fetched from the child process when requested
+        self._peer_id = None
+        self._client_mode = None
+        self._p2p_replica = None
+
+        self._daemon_listen_maddr = p2p.daemon_listen_maddr if p2p is not None else None
+
+        if start:
+            self.run_in_background(await_ready=await_ready)
+
+    def run(self) -> None:
+        """Serve DHT forever. This function will not return until DHT node is shut down"""
+
+        loop = switch_to_uvloop()
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                if self._daemon_listen_maddr is not None:
+                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                else:
+                    replicated_p2p = None
+
+                self._node = await DHTNode.create(
+                    initial_peers=self.initial_peers,
+                    num_workers=self.num_workers,
+                    record_validator=self._record_validator,
+                    p2p=replicated_p2p,
+                    **self.kwargs,
+                )
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
+                try:
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
+                try:
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._node.protocol.wait_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
+
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
+        """
+        Starts DHT in a background process. if await_ready, this method will wait until background dht
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)
+
+    def shutdown(self) -> None:
+        """Shut down a running dht process"""
+        if self.is_alive():
+            self._outer_pipe.send(("_shutdown", [], {}))
+            self.join(self.shutdown_timeout)
+            if self.is_alive():
+                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
+                self.terminate()
+
+    async def _shutdown(self):
+        await self._node.shutdown()
+
+    def get(
+        self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
+    ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
+        """
+        Search for a key across DHT and return either first or latest entry (if found).
+        :param key: same key as in node.store(...)
+        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :param kwargs: parameters forwarded to DHTNode.get_many_by_id
+        :returns: (value, expiration time); if value was not found, returns None
+        """
+        future = MPFuture()
+        self._outer_pipe.send(("_get", [], dict(key=key, latest=latest, future=future, **kwargs)))
+        return future if return_future else future.result()
+
+    async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
+        try:
+            result = await self._node.get(key, latest=latest, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey] = None,
+        return_future: bool = False,
+        **kwargs,
+    ) -> Union[bool, MPFuture]:
+        """
+        Find num_replicas best nodes to store (key, value) and store it there until expiration time.
+
+        :param key: msgpack-serializable key to be associated with value until expiration.
+        :param value: msgpack-serializable value to be stored under a given key until expiration.
+        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
+        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: True if store succeeds, False if it fails (due to no response or newer value)
+        """
+        future = MPFuture()
+        self._outer_pipe.send(
+            (
+                "_store",
+                [],
+                dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey, future=future, **kwargs),
+            )
+        )
+        return future if return_future else future.result()
+
+    async def _store(
+        self,
+        key: DHTKey,
+        value: DHTValue,
+        expiration_time: DHTExpiration,
+        subkey: Optional[Subkey],
+        future: MPFuture,
+        **kwargs,
+    ):
+        try:
+            result = await self._node.store(key, value, expiration_time, subkey=subkey, **kwargs)
+            if not future.done():
+                future.set_result(result)
+        except BaseException as e:
+            if not future.done():
+                future.set_exception(e)
+            raise
+
+    def run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], return_future: bool = False
+    ) -> Union[ReturnType, MPFuture[ReturnType]]:
+        """
+        Execute an asynchronous function on a DHT participant and return results. This is meant as an interface
+         for running custom functions on the DHT for special cases (e.g. declare experts, beam search)
+
+        :param coro: async function to be executed. Receives 2 arguments: this DHT daemon and a running DHTNode
+        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
+        :returns: coroutine outputs or MPFuture for these outputs
+        :note: the coroutine will be executed inside the DHT process. As such, any changes to global variables or
+          DHT fields made by this coroutine will not be accessible from the host process.
+        :note: all time-consuming operations in coro should be asynchronous (e.g. asyncio.sleep instead of time.sleep)
+          or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
+        :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
+        """
+        future = MPFuture()
+        self._outer_pipe.send(("_run_coroutine", [], dict(coro=coro, future=future)))
+        return future if return_future else future.result()
+
+    async def _run_coroutine(
+        self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
+    ):
+        try:
+            future.set_result(await coro(self, self._node))
+        except BaseException as e:
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)
+
+    def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
+        if not self._ready.done():
+            raise RuntimeError(
+                "Can't append new validators before the DHT process has started. "
+                "Consider adding them to the initial list via DHT.__init__(record_validators=...)"
+            )
+
+        self.run_coroutine(partial(DHT._add_validators, record_validators=record_validators))
+
+    @staticmethod
+    async def _add_validators(_dht: DHT, node: DHTNode, record_validators: Iterable[RecordValidatorBase]) -> None:
+        node.protocol.record_validator.extend(record_validators)
+
+    @property
+    def peer_id(self) -> PeerID:
+        if self._peer_id is None:
+            self._peer_id = self.run_coroutine(DHT._get_peer_id)
+        return self._peer_id
+
+    @staticmethod
+    async def _get_peer_id(_dht: DHT, node: DHTNode) -> PeerID:
+        return node.peer_id
+
+    @property
+    def client_mode(self) -> bool:
+        if self._client_mode is None:
+            self._client_mode = self.run_coroutine(DHT._get_client_mode)
+        return self._client_mode
+
+    @staticmethod
+    async def _get_client_mode(_dht: DHT, node: DHTNode) -> bool:
+        return node.protocol.client_mode
+
+    def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
+        """
+        Get multiaddrs of the current DHT node that should be accessible by other peers.
+
+        :param latest: ask the P2P daemon to refresh the visible multiaddrs
+        """
+
+        return self.run_coroutine(partial(DHT._get_visible_maddrs, latest=latest))
+
+    @staticmethod
+    async def _get_visible_maddrs(_dht: DHT, node: DHTNode, latest: bool = False) -> List[Multiaddr]:
+        return await node.get_visible_maddrs(latest=latest)
+
+    async def replicate_p2p(self) -> P2P:
+        """
+        Get a replica of a P2P instance used in the DHT process internally.
+        The replica uses the same P2P daemon as the DHT and only works while DHT is alive.
+        """
+
+        if self._p2p_replica is None:
+            daemon_listen_maddr = self.run_coroutine(DHT._get_p2p_daemon_listen_maddr)
+            self._p2p_replica = await P2P.replicate(daemon_listen_maddr)
+        return self._p2p_replica
+
+    @staticmethod
+    async def _get_p2p_daemon_listen_maddr(_dht: DHT, node: DHTNode) -> Multiaddr:
+        return node.p2p.daemon_listen_maddr
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()

+ 1 - 1
hivemind/dht/routing.py

@@ -10,7 +10,7 @@ from itertools import chain
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union

 from hivemind.p2p import PeerID
-from hivemind.utils import MSGPackSerializer, get_dht_time
+from hivemind.utils import DHTExpiration, MSGPackSerializer, get_dht_time

 DHTKey = Subkey = DHTValue = Any
 BinaryDHTID = BinaryDHTValue = bytes

+ 1 - 0
hivemind/optim/__init__.py

@@ -3,3 +3,4 @@ from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.grad_scaler import HivemindGradScaler
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
+from hivemind.optim.training_averager import TrainingAverager

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence

 import torch.optim

-from hivemind import TrainingAverager
 from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.training_averager import TrainingAverager


 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 2 - 2
hivemind/optim/collaborative.py

@@ -9,14 +9,14 @@ import numpy as np
 import torch
 from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint

-from hivemind.averaging.training import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.grad_scaler import HivemindGradScaler
-from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils.performance_ema import PerformanceEMA

 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
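
Editorial note: collaborative.py now imports TrainingAverager from hivemind.optim.training_averager and PerformanceEMA from hivemind.utils.performance_ema (the latter module apparently relocated elsewhere in this commit). A sketch of the adjusted imports for downstream code that used the old paths, assuming no aliases are left at the old locations:

    # Old paths (removed by this merge):
    # from hivemind.averaging.training import TrainingAverager
    # from hivemind.optim.performance_ema import PerformanceEMA

    # New paths, matching the diff above:
    from hivemind.optim.training_averager import TrainingAverager
    from hivemind.utils.performance_ema import PerformanceEMA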

+ 0 - 0
hivemind/optim/experimental/__init__.py


+ 219 - 0
hivemind/optim/experimental/grad_averager.py

@@ -0,0 +1,219 @@
+import contextlib
+from typing import Iterable, Iterator, Optional
+
+import torch
+
+import hivemind
+from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
+from hivemind.utils import DHTExpiration, get_logger
+
+logger = get_logger(__name__)
+
+
+class GradientAverager(DecentralizedAverager):
+    """
+    An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
+    GradientAverager is meant to be used within hivemind.Optimizer, but it can be used standalone (see example below).
+
+    GradientAverager manages three sets of buffers:
+    (1) model gradients - the gradients associated with local model parameters by PyTorch (param.grad).
+        These tensors are typically stored on device and updated by torch autograd
+    (2) gradient accumulators - an [optional] set of buffers where local gradients are accumulated.
+      - note: if reuse_grad_buffers is True, the averager will use gradients from parameters as local accumulators,
+        which reduces RAM usage but requires the user to avoid calling zero_grad / clip_grad manually
+    (3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
+
+
+    Example:
+
+    >>> model = SuchModelMuchLayers()
+    >>> opt = torch.optim.Adam(model.parameters())
+    >>> grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(...))
+    >>> next_step_time = hivemind.get_dht_time() + 60   # runs global steps every 60 seconds
+    >>> next_step_control = None
+    >>> while True:
+    >>>    # accumulate as many gradients as you can before next_step_time
+    >>>    loss = compute_loss(model, batch_size=32)
+    >>>    loss.backward()
+    >>>    grad_averager.accumulate_grads_(batch_size=32)
+    >>>    # [optional] next step in 5 seconds, start looking for peers in advance
+    >>>    if next_step_time - hivemind.get_dht_time() <= 5:
+    >>>        next_step_control = grad_averager.schedule_step(scheduled_time=next_step_time)
+    >>>    # aggregate gradients and perform optimizer step
+    >>>    if hivemind.get_dht_time() >= next_step_time:
+    >>>        grad_averager.step(control=next_step_control)
+    >>>        with grad_averager.use_averaged_gradients():  # this will fill param.grads with aggregated gradients
+    >>>            opt.step()  # update model parameters using averaged gradients
+    >>>        grad_averager.reset_accumulated_grads_()  # prepare for next step
+    >>>        next_step_time = hivemind.get_dht_time() + 60
+    >>>        next_step_control = None
+
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        *,
+        dht: hivemind.DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: Optional[bool] = None,
+        warn: bool = True,
+        **kwargs,
+    ):
+        if reuse_grad_buffers and accumulate_grads_on is not None:
+            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        self._parameters = tuple(parameters)
+        self.reuse_grad_buffers = reuse_grad_buffers
+        self.warn = warn
+        self.local_samples_accumulated = 0
+        self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        self._local_accumulators = None
+        if not reuse_grad_buffers:
+            self._local_accumulators = tuple(
+                torch.zeros_like(grad, device=accumulate_grads_on) for grad in self._grads_from_parameters()
+            )
+        self._accumulators_used_in_step = False
+        self._new_averaged_grads = False
+
+        with torch.no_grad():
+            averaged_grads = tuple(
+                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+            )
+        super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
+
+    def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
+        """gradient buffers associated with parameters"""
+        for param in self._parameters:
+            if param.grad is None:
+                param.grad = torch.zeros_like(param)
+            yield param.grad
+
+    @torch.no_grad()
+    def _grad_accumulators(self) -> Iterator[torch.Tensor]:
+        """averager-based gradient accumulators"""
+        assert (self._local_accumulators is None) == self.reuse_grad_buffers
+        yield from self._grads_from_parameters() if self.reuse_grad_buffers else self._local_accumulators
+
+    @torch.no_grad()
+    def accumulate_grads_(self, batch_size: int):
+        """add current gradients to local grad accumulators (if used)"""
+        if self._accumulators_used_in_step and self.warn:
+            logger.warning(
+                "[warn=True] Gradient accumulators were not reset since the last averaging round. Please "
+                "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)."
+            )
+            self._accumulators_used_in_step = False  # warn once per round
+        if self._anchor_batch_size is None:
+            # remember the first batch size to correctly re-scale gradients if subsequent batches have a different size
+            self._anchor_batch_size = batch_size
+        self.local_samples_accumulated += batch_size
+        self.local_times_accumulated += 1
+        if self.reuse_grad_buffers:
+            pass  # user is responsible for accumulating gradients in .grad buffers
+        else:
+            alpha = float(batch_size) / self._anchor_batch_size
+            for grad_buf, grad_acc in zip(self._grads_from_parameters(), self._grad_accumulators()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
+    def step(
+        self,
+        weight: Optional[float] = None,
+        reset_accumulators: bool = True,
+        control: Optional[StepControl] = None,
+        wait: bool = True,
+        **kwargs,
+    ):
+        """
+        Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators
+
+        :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
+        :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
+        :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
+        :param wait: if True, await for the step to finish (or fail), otherwise run all-reduce in background
+        """
+        if control is None:
+            control = self.schedule_step(**kwargs)
+        elif len(kwargs) > 0:
+            raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
+        assert not control.triggered, f"This {type(control)} instance was already used."
+        self._load_accumulators_into_averager_()
+        self._accumulators_used_in_step = True
+        self._new_averaged_grads = True
+
+        control.weight = self.local_samples_accumulated if weight is None else weight
+        if reset_accumulators:
+            self.reset_accumulated_grads_()
+
+        control.allow_allreduce()
+        return control.result() if wait else control
+
+    @torch.no_grad()
+    def _load_accumulators_into_averager_(self):
+        """load locally accumulated gradients into the averager for aggregation"""
+        if self._new_averaged_grads and self.warn:
+            logger.warning(
+                "[warn=True] Starting new averaging round, but previous round results were not used."
+                "This may be a sign of incorrect optimizer behavior."
+            )
+            self._new_averaged_grads = False  # warn once per round
+        # divide locally accumulated gradients by the number of times they were accumulated
+        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
+        with self.get_tensors() as averaged_grads:
+            for grad_acc, averaged_grad in zip(self._grad_accumulators(), averaged_grads):
+                averaged_grad.copy_(grad_acc, non_blocking=True).mul_(grad_scale)
+
+    @torch.no_grad()
+    def reset_accumulated_grads_(self):
+        """reset averager-internal gradient accumulators and the denominator"""
+        self._accumulators_used_in_step = False
+        self.local_samples_accumulated = self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        for grad_buf in self._grad_accumulators():
+            grad_buf.zero_()
+
+    @contextlib.contextmanager
+    @torch.no_grad()
+    def use_averaged_gradients(self):
+        self._new_averaged_grads = False
+        with self.get_tensors() as averaged_grads:
+            try:
+                assert len(averaged_grads) == len(self._parameters)
+                old_grads = [param.grad for param in self._parameters]
+                for param, new_grad in zip(self._parameters, averaged_grads):
+                    param.grad = new_grad
+                yield
+            finally:
+                for param, old_grad in zip(self._parameters, old_grads):
+                    param.grad = old_grad
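
Editorial note: GradientAverager separates matchmaking (schedule_step) from triggering all-reduce (step with control=...). The sketch below shows that handoff in the non-blocking variant; the setup lines are assumptions for illustration (a toy model, a standalone DHT, and the module path added in this commit), not a recommended configuration:

    import torch
    import hivemind
    from hivemind.optim.experimental.grad_averager import GradientAverager

    dht = hivemind.DHT(start=True)  # standalone DHT just for this demo
    model = torch.nn.Linear(8, 8)
    averager = GradientAverager(model.parameters(), dht=dht, prefix="demo_run", start=True)

    # Arrange a group ~30 seconds ahead; matchmaking proceeds in the background.
    control = averager.schedule_step(scheduled_time=hivemind.get_dht_time() + 30)

    # ... in the meantime: loss.backward(); averager.accumulate_grads_(batch_size=...) ...

    # Hand the pre-arranged group back to step(); wait=False returns the StepControl immediately.
    control = averager.step(control=control, wait=False)
    control.result()  # block here until the all-reduce round finishes (or raises)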

+ 569 - 0
hivemind/optim/experimental/state_averager.py

@@ -0,0 +1,569 @@
+""" An extension of averager that supports common optimization use cases. """
+import logging
+from asyncio import Future
+from concurrent.futures import ThreadPoolExecutor
+from itertools import chain
+from threading import Event
+from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
+
+import torch
+
+import hivemind
+from hivemind import nested_compare
+from hivemind.averaging import DecentralizedAverager
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.utils import get_logger, nested_flatten, nested_map, nested_pack
+
+logger = get_logger(__name__)
+
+
+Parameters = Iterable[torch.Tensor]
+ParamGroups = Iterable[Dict[str, Any]]
+TorchOptimizer = torch.optim.Optimizer
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
+SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
+
+
+class TrainingStateAverager(DecentralizedAverager):
+    """
+    An auxiliary class that holds peer's training state, including model parameters, optimizer statistics, scheduler
+    and any other variables that define the local training state (e.g. batchnorm moving averages).
+    TrainingStateAverager is intended to keep these parameters weakly synchronized across the swarm.
+
+    The intended use is to call .step(optimizer_step=..., averaging_round=...) periodically, e.g. after every batch.
+    If peer gets out of sync with the swarm, one should call state_averager.load_state_from_peers() to re-synchronize.
+
+    Example:
+
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, param_groups=model.parameters(), ...)
+    >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
+    >>> avgr.load_state_from_peers()
+    >>> for i, batch in enumerate(training_dataloader):
+    >>>     loss = compute_loss(model, batch)
+    >>>     loss.backward()
+    >>>     avgr.step(optimizer_step=i % 10 == 0, averaging_round=is_it_time_for_averaging(), delay_averaging=True)
+
+    :note: when using delay_averaging or delay_optimizer_step, calling optimizer directly is not recommended because
+      it may overlap with delayed updates from a background thread with unpredictable results. Instead, please call
+      TrainingStateAverager.step(..., optimizer_step=True)
+
+    :param optimizer: PyTorch Optimizer or a callable that creates an optimizer from param groups
+    :param param_groups: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
+    :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
+    :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
+      state tensors. If False, user must make sure that all tensors are pre-initialized at init.
+      By default, initialize optimizer unless it already has some state tensors to begin with.
+    :param offload_optimizer: if True, create optimizer on top of averaged parameters which may save device memory.
+    :param custom_gradients: if True, do *not* automatically load local gradients into the offloaded optimizer.
+      This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
+    :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
+      For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
+    :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
+    :param parameter_names: optionally provide parameter names in the same order as param_groups
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+    :note: you can use extra_tensors for any tensors not used by the optimizer (e.g. batchnorm statistics)
+    :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: hivemind.DHT,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        initialize_optimizer: Optional[bool] = None,
+        offload_optimizer: bool = False,
+        custom_gradients: bool = False,
+        reuse_tensors: bool = False,
+        sync_epoch_when_averaging: bool = False,
+        parameter_names: Optional[Sequence[str]] = None,
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        status_loglevel: int = logging.DEBUG,
+        **kwargs,
+    ):
+        average_opt_statistics = tuple(average_opt_statistics)
+        assert all(isinstance(key, str) for key in average_opt_statistics)
+        if offload_optimizer and reuse_tensors:
+            logger.warning("Setting offload_optimizer=True has no effect because reuse_parameters=True")
+        if custom_gradients and not offload_optimizer:
+            logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, param_groups, parameter_names)
+
+        self.status_loglevel = status_loglevel
+        self.reuse_tensors = reuse_tensors
+        self.offload_optimizer = offload_optimizer
+        self.custom_gradients = custom_gradients
+
+        self._main_parameters, self._parameter_names = main_parameters, parameter_names
+        self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
+        self.optimizer, self.scheduler = self._init_components(
+            param_groups, optimizer, scheduler, initialize_optimizer
+        )
+        self.opt_keys_for_averaging, self.extra_tensors = average_opt_statistics, extra_tensors
+        self.sync_epoch_when_averaging = sync_epoch_when_averaging
+        self.local_epoch = 0
+
+        self.step_executor = ThreadPoolExecutor(max_workers=1)
+        self.finished_optimizer_step = Event()
+        self.finished_averaging_round = Event()
+        self.pending_update = Future()
+        self.pending_update.set_result(None)
+
+        super().__init__(
+            dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
+        )
+
+    @staticmethod
+    def _check_params(
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        param_groups: Optional[Union[Parameters, ParamGroups]],
+        parameter_names: Optional[Sequence[str]],
+    ) -> Tuple[ParamGroups, Sequence[torch.Tensor], Sequence[str]]:
+        """Get and verify parameters, groups and names"""
+        if param_groups is None:
+            assert hasattr(optimizer, "param_groups"), "Must provide param_groups or an optimizer with .param_groups"
+            param_groups = optimizer.param_groups
+        param_groups = tuple(param_groups)
+        if all(isinstance(p, torch.Tensor) for p in param_groups):
+            param_groups = (dict(params=param_groups),)
+        for group in param_groups:
+            assert isinstance(group, dict) and group.get("params") is not None
+            assert all(isinstance(p, torch.Tensor) for p in group["params"])
+        parameters = tuple(chain(*(group["params"] for group in param_groups)))
+        if parameter_names is None:
+            parameter_names = tuple(range(len(parameters)))
+        parameter_names = tuple(nested_flatten(parameter_names))
+        assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
+        assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
+        return param_groups, parameters, parameter_names
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
+        """Create a new tensor for averaging or reuse the existing one"""
+        if self.reuse_tensors:
+            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
+            if not source_tensor.is_shared():
+                source_tensor.share_memory_()
+            return source_tensor
+        else:
+            averaged_tensor = source_tensor.detach().to(device="cpu", dtype=torch.float32, copy=True)
+            return averaged_tensor.share_memory_().requires_grad_(source_tensor.requires_grad)
+
+    def _init_components(
+        self,
+        param_groups: ParamGroups,
+        optimizer_or_factory: Union[TorchOptimizer, OptimizerFactory],
+        scheduler_or_factory: Optional[Union[LRSchedulerBase, SchedulerFactory]],
+        initialize_optimizer: Optional[bool],
+    ) -> Tuple[TorchOptimizer, Optional[LRSchedulerBase]]:
+        """Get optimizer and scheduler by either instantiating user-provided factory or using pre-instantiated ones"""
+        assert hasattr(self, "_averaged_parameters"), "Internal error: must initialize averaged parameters first"
+        optimizer_is_factory = callable(optimizer_or_factory) and not isinstance(optimizer_or_factory, TorchOptimizer)
+        scheduler_is_factory = callable(scheduler_or_factory) and not isinstance(scheduler_or_factory, LRSchedulerBase)
+        if optimizer_is_factory and not scheduler_is_factory and scheduler_or_factory is not None:
+            raise ValueError("If optimizer is created internally, scheduler must also be initialized internally")
+        if self.offload_optimizer and not optimizer_is_factory:
+            raise ValueError("Using offload_optimizer requires creating optimizer inside hivemind")
+
+        # create optimizer
+        if optimizer_is_factory:
+            if self.offload_optimizer:
+                for param in self._averaged_parameters:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
+
+                next_index = 0
+                param_groups_for_optimizer = []
+                for param_group in param_groups:
+                    num_params = len(param_group["params"])
+                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
+                    param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
+                    next_index += num_params
+                assert next_index == len(self._averaged_parameters)
+
+            else:
+                param_groups_for_optimizer = param_groups
+            optimizer = optimizer_or_factory(param_groups_for_optimizer)
+        else:
+            optimizer = optimizer_or_factory
+
+        # optionally initialize optimizer state dict
+        if initialize_optimizer is None:
+            initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
+            logger.log(
+                self.status_loglevel,
+                "Initializing optimizer manually since it has no tensors in state dict"
+                "To override this, please provide initialize_optimizer=False",
+            )
+
+        if initialize_optimizer:
+            initialize_optimizer_state_(optimizer)  # note: this will run one optimizer step!
+
+        # create LR scheduler
+        if scheduler_is_factory:
+            assert callable(scheduler_or_factory)
+            scheduler = scheduler_or_factory(optimizer)
+        else:
+            scheduler = scheduler_or_factory
+
+        # verify optimizer and scheduler
+        assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
+        if self.offload_optimizer or self.reuse_tensors:
+            for param_group in optimizer.param_groups:
+                for param in param_group["params"]:
+                    assert param.is_shared()
+        assert isinstance(scheduler, (LRSchedulerBase, type(None)))
+        if scheduler is not None:
+            assert scheduler.optimizer == optimizer
+        return optimizer, scheduler
+
+    def _local_tensors(self) -> Iterator[torch.Tensor]:
+        """Iterate local trainer's tensors that should be averaged with peers"""
+        for param_group in self.optimizer.param_groups:
+            yield from param_group["params"]
+        for stats in self.opt_keys_for_averaging:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group["params"]:
+                    yield self.optimizer.state[param][stats]
+        yield from self.extra_tensors
+
+    @torch.no_grad()
+    def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
+        """Create or reuse a tuple of all averaged tensors, including parameters, optimizer statistics and extras"""
+        assert hasattr(self, "optimizer"), "Optimizer should already be initialized by this point"
+        assert hasattr(self, "_averaged_parameters"), "Should initialize _averaged_parameters first"
+        assert not hasattr(self, "_averaged_tensors"), "Averager is already initialized"
+        assert all(isinstance(key, str) for key in self.opt_keys_for_averaging)
+
+        local_tensors = tuple(self._local_tensors())
+        local_non_parameters = local_tensors[len(self._averaged_parameters) :]
+        averaged_tensors = tuple(map(torch.Tensor.detach, self._averaged_parameters))
+        averaged_non_parameters = tuple(map(self._make_host_tensor, local_non_parameters))
+        averaged_tensors = tuple(chain(averaged_tensors, averaged_non_parameters))
+
+        assert len(averaged_tensors) == len(local_tensors)
+        for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+            assert local_tensor.shape == averaged_tensor.shape
+            if averaged_tensor.grad is not None:
+                logger.log(self.status_loglevel, "Setting gradients for averaged tensor to None")
+
+        return averaged_tensors
+
+    def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
+        """Get CompressionInfo for each state tensor, accounting for its role and specification"""
+        tensor_infos = []
+        for param, param_name in zip(self._main_parameters, self._parameter_names):
+            tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
+        for stats_name in self.opt_keys_for_averaging:
+            opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+            assert len(opt_parameters) == len(self._parameter_names)
+            for param, param_name in zip(opt_parameters, self._parameter_names):
+                tensor_infos.append(
+                    CompressionInfo.from_tensor(
+                        self.optimizer.state[param][stats_name],
+                        key=(param_name, stats_name),
+                        role=TensorRole.OPTIMIZER,
+                    )
+                )
+        for i, extra_tensor in enumerate(self.extra_tensors):
+            tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
+        return tuple(tensor_infos)
+
+    def step(
+        self,
+        wait_for_delayed_update: Optional[bool] = None,
+        apply_delayed_updates: bool = True,
+        increment_epoch: bool = False,
+        optimizer_step: bool = False,
+        zero_grad: bool = False,
+        delay_optimizer_step: bool = False,
+        averaging_round: bool = False,
+        delay_averaging: Optional[bool] = None,
+        averaging_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Perform one or several possible actions, depending on the specified keyword args.
+        The actions will be performed in the same order as specified below:
+
+        :param wait_for_delayed_update: if there are delayed background updates, wait for them to finish.
+          By default, wait only when scheduling an optimizer step, zero_grad or averaging round; otherwise do not wait.
+        :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
+        :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
+        :param zero_grad: if True, reset local gradients after performing optimizer step
+        :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
+        :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
+        :param delay_averaging: if True, perform averaging in background and apply results in a future step;
+          by default, averaging is delayed if the optimizer step is also delayed. Set to True to delay only this phase.
+        :param averaging_kwargs: a dict of keyword arguments forwarded into averaging round
+        """
+        if delay_averaging is None:
+            delay_averaging = delay_optimizer_step
+        if wait_for_delayed_update is None:
+            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
+        assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
+        if optimizer_step or averaging_round or zero_grad:
+            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
+        if delay_optimizer_step:
+            assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
+            assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
+        if averaging_kwargs and not averaging_round:
+            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_kwargs}")
+        output = None
+
+        if wait_for_delayed_update:
+            if not self.pending_update.done():
+                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                output = self.pending_update.result()
+
+        if self.pending_update.done() and self.pending_update.exception():
+            logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")
+
+        if apply_delayed_updates:
+            if self.finished_averaging_round.is_set():
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                logger.log(self.status_loglevel, "Received results from background averaging round")
+                self.finished_averaging_round.clear()
+
+            if self.finished_optimizer_step.is_set():
+                if self.offload_optimizer:
+                    self._apply_optimizer_results_()
+                logger.log(self.status_loglevel, "Received results from background optimizer step")
+                self.finished_optimizer_step.clear()
+
+        if increment_epoch:
+            self.local_epoch += 1
+            logger.log(self.status_loglevel, f"Switching to epoch {self.local_epoch}")
+            self._update_scheduler()
+
+        if optimizer_step or zero_grad or averaging_round:
+            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
+
+            if self.offload_optimizer and not self.custom_gradients:
+                self._load_local_grads_into_optimizer_()
+
+            self.pending_update = self.step_executor.submit(
+                self._do,
+                optimizer_step,
+                zero_grad,
+                averaging_round,
+                **averaging_kwargs or {},
+            )
+
+            if (optimizer_step or zero_grad) and not delay_optimizer_step:
+                self.finished_optimizer_step.wait()
+                self.finished_optimizer_step.clear()
+                if self.offload_optimizer:
+                    self._apply_optimizer_results_()
+                logger.log(self.status_loglevel, "Finished optimizer step")
+
+            if averaging_round and not delay_averaging:
+                self.finished_averaging_round.wait()
+                self.finished_averaging_round.clear()
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                logger.log(self.status_loglevel, "Finished averaging round")
+
+            if not delay_averaging:
+                try:
+                    output = self.pending_update.result()
+                finally:
+                    self.finished_averaging_round.clear()
+                    self.finished_optimizer_step.clear()
+        return output
+
+    def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kwargs):
+        """
+        Run the optimizer step and/or zero out gradients, followed by an averaging round; each stage is optional.
+        This method is meant to be called in the background executor.
+        """
+        try:
+            if optimizer_step:
+                logger.log(self.status_loglevel, f"Running optimizer step")
+                self.optimizer.step()
+            if zero_grad:
+                logger.log(self.status_loglevel, f"Running zero grad")
+                self.optimizer.zero_grad()
+                if self.offload_optimizer:
+                    for parameter in self._main_parameters:
+                        if parameter.grad is not None:
+                            parameter.grad.zero_()
+
+            self.finished_optimizer_step.set()
+
+            if averaging_round:
+                if not self.reuse_tensors:
+                    self._load_local_tensors_into_averager_()
+                try:
+                    gathered = super().step(gather=self.local_epoch, **kwargs)
+                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                except BaseException as e:
+                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                    self.finished_averaging_round.set()
+                    gathered = {}
+
+                self.finished_averaging_round.set()
+
+                if self.sync_epoch_when_averaging:
+                    old_epoch = self.local_epoch
+                    for peer_epoch in gathered.values():
+                        self.local_epoch = max(self.local_epoch, peer_epoch)
+                    if self.local_epoch != old_epoch:
+                        logger.log(self.status_loglevel, f"Found peer with newer epoch ({self.local_epoch})")
+                        self._update_scheduler()
+
+        except Exception as e:
+            logger.exception(e)
+            self.finished_optimizer_step.set()
+            self.finished_averaging_round.set()
+
+    @torch.no_grad()
+    def _load_local_grads_into_optimizer_(self):
+        """Copy local gradients into the gradient buffers of the offloaded optimizer"""
+        assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
+        opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        for main_param, opt_param in zip(self._main_parameters, opt_parameters):
+            if main_param.grad is not None:
+                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_optimizer_results_(self):
+        """Copy parameters from offloaded optimizer to the main model"""
+        assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
+        with self.lock_averaged_tensors:
+            offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+            assert len(offloaded_parameters) == len(self._main_parameters), "opt parameters changed during training"
+            for main_param, offloaded_param in zip(self._main_parameters, offloaded_parameters):
+                main_param.copy_(offloaded_param, non_blocking=True)
+
+    @torch.no_grad()
+    def _load_local_tensors_into_averager_(self):
+        """Copy local tensors into the averaging buffers"""
+        assert not self.reuse_tensors, "No need to load tensors into averager: both tensors share the same memory"
+        with self.get_tensors() as averaged_tensors:
+            for local_tensor, averaged_tensor in zip(self._local_tensors(), averaged_tensors):
+                averaged_tensor.copy_(local_tensor, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_averaging_results_(self):
+        """Copy averaged tensors into their respective local tensors"""
+        assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        with self.get_tensors() as averaged_tensors:
+            local_tensors = list(self._local_tensors())
+            assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
+            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                local_tensor.copy_(averaged_tensor, non_blocking=True)
+
+    def get_current_state(self):
+        """
+        Get current model/optimizer state when requested by a newbie peer. Executed in the host process.
+        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        """
+        with torch.no_grad():
+            optimized_parameters = tuple(
+                param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
+            )
+            parameter_infos = [
+                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+                for param, key in zip(optimized_parameters, self._parameter_names)
+            ]
+            extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            extra_infos = [
+                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+                for i, extra_tensor in enumerate(extra_tensors)
+            ]
+            optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.optimizer)
+            optimizer_infos = [
+                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
+                for i, opt_tensor in enumerate(optimizer_tensors)
+            ]
+
+        metadata = dict(
+            epoch=self.local_epoch, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata
+        )
+        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
+        return metadata, all_tensors, all_tensor_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
+        parameters_and_extras = tuple(chain(self._main_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(parameters_and_extras)
+
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        if (not isinstance(metadata.get("epoch"), int)) or metadata["epoch"] < self.local_epoch:
+            logger.warning("Cowardly refusing to load state from peer: peer's epoch is behind our local epoch")
+            return
+
+        loaded_parameters_and_extras = flat_tensors[:num_parameters_and_extras]
+        loaded_opt_tensors = flat_tensors[num_parameters_and_extras:]
+        if num_parameters_and_extras != len(loaded_parameters_and_extras):
+            logger.error("Failed to load state from peer, received parameters, extras or metadata.")
+            return
+
+        try:
+            load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+        except StopIteration:
+            logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+            return
+
+        with torch.no_grad():
+            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
+                local_param.copy_(loaded_param, non_blocking=True)
+        self.local_epoch = metadata["epoch"]
+        self._update_scheduler()
+
+    def _update_scheduler(self):
+        """Increase the scheduler state until it becomes synchronized with local epoch"""
+        if self.scheduler:
+            while self.scheduler._step_count <= self.local_epoch:
+                self.scheduler.step()
+
+
+def initialize_optimizer_state_(opt: torch.optim.Optimizer):
+    """Initialize optimizer statistics by running a virtual optimizer step with zero gradients"""
+    flat_params = tuple(param for group in opt.param_groups for param in group["params"])
+    old_grads = []
+    for param in flat_params:
+        old_grads.append(param.grad)
+        param.grad = torch.zeros_like(param)
+    opt.step()
+    for param, old_grad in zip(flat_params, old_grads):
+        param.grad = old_grad
+
+
+def dump_optimizer_state(opt: torch.optim.Optimizer):
+    """Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers"""
+    with torch.no_grad():
+        flat_metadata, flat_tensors = [], []
+        for elem in nested_flatten(opt.state_dict()):
+            if isinstance(elem, torch.Tensor):
+                flat_metadata.append(dict(type="tensor", index=len(flat_tensors)))
+                flat_tensors.append(elem.cpu())
+            else:
+                flat_metadata.append(dict(type="value", value=elem))
+        return flat_metadata, flat_tensors
+
+
+def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Sequence[Dict], flat_tensors: Sequence[torch.Tensor]):
+    """Load a state obtained by dump_optimizer_state back into the optimizer"""
+    flat_optimizer_state = []
+    for elem in flat_metadata:
+        if elem.get("type") == "tensor" and isinstance(elem.get("index"), int):
+            flat_optimizer_state.append(flat_tensors[elem["index"]])
+        elif elem.get("type") == "value" and "value" in elem:
+            flat_optimizer_state.append(elem["value"])
+    return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
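A minimal usage sketch of the TrainingStateAverager defined above, mirroring the constructor arguments and step() flags exercised by tests/test_optimizer.py later in this diff; the DHT, prefix, learning rate and group size below are illustrative placeholders rather than part of this change:

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

import hivemind
from hivemind.optim.experimental.state_averager import TrainingStateAverager

dht = hivemind.DHT(start=True)
model = nn.Linear(2, 3)
averager = TrainingStateAverager(
    dht=dht,
    optimizer=partial(torch.optim.Adam, lr=0.1),  # a factory is required when offload_optimizer=True
    param_groups=model.parameters(),
    offload_optimizer=True,
    prefix="my_exp",        # forwarded to DecentralizedAverager
    target_group_size=2,    # forwarded to DecentralizedAverager
    start=True,
)

F.mse_loss(model(torch.ones(2)), torch.zeros(3)).backward()
# run a local optimizer step and clear gradients; pass averaging_round=True to also average state with peers
averager.step(optimizer_step=True, zero_grad=True)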

+ 1 - 1
hivemind/optim/simple.py

@@ -4,9 +4,9 @@ from typing import Optional, Sequence, Tuple
 
 import torch
 
-from hivemind.averaging import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)

+ 0 - 0
hivemind/averaging/training.py → hivemind/optim/training_averager.py


+ 23 - 1
hivemind/utils/asyncio.py

@@ -1,7 +1,8 @@
 import asyncio
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
+from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
@@ -147,3 +148,24 @@ async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Ev
             yield item
     finally:
         event.set()
+
+
+class _AsyncContextWrapper(AbstractAsyncContextManager):
+    """Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
+
+    def __init__(self, context: AbstractContextManager):
+        self._context = context
+
+    async def __aenter__(self):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self._context.__enter__)
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        return self._context.__exit__(exc_type, exc_value, traceback)
+
+
+@asynccontextmanager
+async def enter_asynchronously(context: AbstractContextManager):
+    """Wrap a non-async context so that it can be entered asynchronously"""
+    async with _AsyncContextWrapper(context) as ret_value:
+        yield ret_value
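A quick illustration of enter_asynchronously added above (a sketch; the mp.Lock stands in for any blocking context manager, mirroring tests/test_util_modules.py::test_async_context later in this diff). The blocking __enter__ is pushed to the default executor, so the event loop keeps serving other coroutines while one of them waits for the lock:

import asyncio
import multiprocessing as mp

from hivemind.utils.asyncio import enter_asynchronously


async def hold_lock(lock, delay: float):
    # lock.__enter__ blocks a worker thread, not the event loop
    async with enter_asynchronously(lock):
        await asyncio.sleep(delay)


async def main():
    lock = mp.Lock()
    # the second coroutine waits for the lock without stalling the first one
    await asyncio.gather(hold_lock(lock, 0.2), hold_lock(lock, 0.1))


asyncio.run(main())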

+ 5 - 1
hivemind/optim/performance_ema.py → hivemind/utils/performance_ema.py

@@ -37,6 +37,10 @@ class PerformanceEMA:
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
         return self.samples_per_second
 
+    def reset_timer(self):
+        """Reset the time since the last update so that the next task performance is counted from current time"""
+        self.timestamp = time.perf_counter()
+
     @contextmanager
     def pause(self):
         """While inside this context, EMA will not count the time passed towards the performance estimate"""
@@ -44,8 +48,8 @@
         try:
             yield
         finally:
-            self.timestamp = time.perf_counter()
             self.paused = was_paused
+            self.reset_timer()
 
     def __repr__(self):
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"

+ 5 - 2
tests/conftest.py

@@ -1,13 +1,13 @@
 import asyncio
 import gc
-import multiprocessing as mp
 from contextlib import suppress
 
 import psutil
 import pytest
 
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.mpfuture import MPFuture, SharedBytes
+from hivemind.utils.mpfuture import MPFuture
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -33,6 +33,9 @@ def event_loop():
 def cleanup_children():
     yield
 
+    with RSAPrivateKey._process_wide_key_lock:
+        RSAPrivateKey._process_wide_key = None
+
     gc.collect()  # Call .__del__() for removed objects
 
     children = psutil.Process().children(recursive=True)

+ 2 - 2
tests/test_averaging.py

@@ -481,7 +481,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.averaging.TrainingAverager(
+    averager1 = hivemind.TrainingAverager(
         opt1,
         average_gradients=True,
         average_parameters=True,
@@ -492,7 +492,7 @@
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.averaging.TrainingAverager(
+    averager2 = hivemind.TrainingAverager(
         opt2,
         average_gradients=True,
         average_parameters=True,
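After the move, this commit refers to the same class in two ways: tests/test_averaging.py uses hivemind.TrainingAverager, while hivemind/optim/simple.py imports it from hivemind.optim.training_averager. A one-line sanity check (a sketch, assuming the layout introduced by this diff):

import hivemind
from hivemind.optim.training_averager import TrainingAverager

assert hivemind.TrainingAverager is TrainingAverager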

+ 172 - 0
tests/test_optimizer.py

@@ -0,0 +1,172 @@
+import time
+from functools import partial
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import hivemind
+from hivemind.averaging.control import AveragingStage
+from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.state_averager import TrainingStateAverager
+
+
+@pytest.mark.forked
+def test_grad_averager():
+    dht1 = hivemind.DHT(start=True)
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager1 = GradientAverager(
+        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
+    )
+
+    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager2 = GradientAverager(
+        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
+    )
+
+    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
+    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)
+
+    for i in range(10):
+        time.sleep(0.1)
+        if i % 3 == 0:
+            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1.backward()
+            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
+            model1.zero_grad()
+        else:
+            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2.backward()
+            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
+            # note: we do not call zero grad here because reuse_grad_buffers=True
+
+    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
+    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
+    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
+    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
+
+    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
+    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
+
+    averager1.step(control=control1, wait=False)
+    averager2.step(control=control2, wait=False)
+    for step in (control1, control2):
+        step.result()  # wait for all-reduce to finish
+
+    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
+    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
+    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
+    with averager1.use_averaged_gradients():
+        assert torch.allclose(model1.w.grad, ref_average)
+    with averager2.use_averaged_gradients():
+        assert torch.allclose(model2.w.grad, ref_average)
+
+    # after no longer use_averaged_gradients
+    assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
+    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
+)
+def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    torch.manual_seed(1337)
+    torch.use_deterministic_algorithms(True)
+    # note: use_deterministic_algorithms does not affect further tests because this test is forked
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    extras1 = (torch.randn(2, 2), -torch.rand(1))
+    extras2 = (-torch.randn(2, 2), torch.rand(1))
+
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        sync_epoch_when_averaging=sync_epoch_when_averaging,
+        average_opt_statistics=("exp_avg_sq",),
+        offload_optimizer=offload_optimizer,
+        reuse_tensors=reuse_tensors,
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, param_groups=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
+    )
+    avgr2 = TrainingStateAverager(
+        dht=dht2, param_groups=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
+    )
+
+    x = torch.ones(2)
+
+    for step in range(20):
+        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
+        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)
+
+        F.mse_loss(model2(x), -torch.ones(3)).backward()
+        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
+
+    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
+    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
+    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"
+
+    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    assert not torch.allclose(stats1, stats2)
+
+    avgr1.step(increment_epoch=True)
+
+    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+
+    avgr1.step(wait_for_delayed_update=True)
+    avgr2.step(wait_for_delayed_update=True)
+
+    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
+    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert avgr1.local_epoch == 2
+    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)
+
+
+@pytest.mark.forked
+def test_load_state_from_peers():
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.SGD, lr=0.1),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, param_groups=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+    )
+
+    avgr2 = TrainingStateAverager(dht=dht2, param_groups=model2.parameters(), start=True, **common_kwargs)
+
+    avgr2.local_epoch = 1337
+    model2.weight.data[...] = 42
+    time.sleep(0.1)
+
+    avgr1.load_state_from_peers()
+    assert avgr1.local_epoch == 1337
+    assert torch.all(model1.weight == 42).item()
+    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)

+ 19 - 1
tests/test_util_modules.py

@@ -11,7 +11,6 @@ import torch
 
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
@@ -28,8 +27,10 @@ from hivemind.utils.asyncio import (
     attach_event_on_finished,
     azip,
     cancel_and_wait,
+    enter_asynchronously,
 )
 from hivemind.utils.mpfuture import InvalidStateError
+from hivemind.utils.performance_ema import PerformanceEMA
 
 
 @pytest.mark.forked
@@ -538,6 +539,23 @@ async def test_cancel_and_wait():
     assert not await cancel_and_wait(task_with_error)
 
 
+@pytest.mark.asyncio
+async def test_async_context():
+    lock = mp.Lock()
+
+    async def coro1():
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.2)
+
+    async def coro2():
+        await asyncio.sleep(0.1)
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.1)
+
+    await asyncio.wait_for(asyncio.gather(coro1(), coro2()), timeout=0.5)
+    # running this without enter_asynchronously would deadlock the event loop
+
+
 def test_batch_tensor_descriptor_msgpack():
     tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
     tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))