
move PerformanceEMA to utils, TrainingAverager to optim, update utils (#405)

* implement and test an async wrapper for context managers (used in DecentralizedAverager and ProgressTracker)
* implement .reset_timer in PerformanceEMA (used when progress is reset, e.g. on fp16 gradient overflow, which should not affect the samples-per-second estimate)
* move PerformanceEMA to hivemind.utils (rationale: it will be used in hivemind.moe for @mryab's pipelining experiments)
* move TrainingAverager to hivemind.optim (for compliance with hivemind.Optimizer and future deprecation in favour of TrainingStateAverager); downstream import changes are sketched below
* fix process-wide RSA keys in the validator

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic, 3 years ago
parent commit ed4204009f
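
For downstream code, the moves in this commit change the public import paths. A brief before/after sketch, based on the diffs below:

# before this commit
from hivemind.averaging import TrainingAverager
from hivemind.optim.performance_ema import PerformanceEMA

# after this commit (TrainingAverager stays re-exported as hivemind.TrainingAverager)
from hivemind.optim import TrainingAverager
from hivemind.utils.performance_ema import PerformanceEMA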

+ 2 - 1
hivemind/__init__.py

@@ -1,4 +1,4 @@
-from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
@@ -16,6 +16,7 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
+    TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *

+ 0 - 1
hivemind/averaging/__init__.py

@@ -1,2 +1 @@
 from hivemind.averaging.averager import DecentralizedAverager
-from hivemind.averaging.training import TrainingAverager

+ 10 - 11
hivemind/averaging/averager.py

@@ -32,7 +32,15 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
+from hivemind.utils.asyncio import (
+    achain,
+    aiter_with_timeout,
+    anext,
+    as_aiter,
+    azip,
+    enter_asynchronously,
+    switch_to_uvloop,
+)
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -453,7 +461,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )
 
-            async with self.get_tensors_async() as local_tensors:
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -505,15 +513,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
 
-    @contextlib.asynccontextmanager
-    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
-        """Like get_tensors, but uses an asynchronous contextmanager"""
-        try:
-            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
-            yield self._averaged_tensors
-        finally:
-            self.lock_averaged_tensors.release()
-
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
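
Callers that relied on the removed get_tensors_async can reproduce it with the new generic helper; a minimal sketch, assuming an existing DecentralizedAverager instance:

from hivemind.utils.asyncio import enter_asynchronously

async def read_averaged_tensors(averager):
    # get_tensors() acquires lock_averaged_tensors under the hood; enter_asynchronously
    # enters it in a thread-pool executor, much like the removed helper did via run_in_executor
    async with enter_asynchronously(averager.get_tensors()) as tensors:
        return [tensor.clone() for tensor in tensors]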

+ 1 - 0
hivemind/optim/__init__.py

@@ -3,3 +3,4 @@ from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.grad_scaler import HivemindGradScaler
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
+from hivemind.optim.training_averager import TrainingAverager

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence
 
 import torch.optim
 
-from hivemind import TrainingAverager
 from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.training_averager import TrainingAverager
 
 
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 2 - 2
hivemind/optim/collaborative.py

@@ -9,14 +9,14 @@ import numpy as np
 import torch
 from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
 
-from hivemind.averaging.training import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.grad_scaler import HivemindGradScaler
-from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils.performance_ema import PerformanceEMA
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

+ 1 - 1
hivemind/optim/simple.py

@@ -4,9 +4,9 @@ from typing import Optional, Sequence, Tuple
 
 import torch
 
-from hivemind.averaging import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)

+ 0 - 0
hivemind/averaging/training.py → hivemind/optim/training_averager.py


+ 23 - 1
hivemind/utils/asyncio.py

@@ -1,7 +1,8 @@
 import asyncio
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
+from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
@@ -147,3 +148,24 @@ async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Ev
             yield item
     finally:
         event.set()
+
+
+class _AsyncContextWrapper(AbstractAsyncContextManager):
+    """Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
+
+    def __init__(self, context: AbstractContextManager):
+        self._context = context
+
+    async def __aenter__(self):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self._context.__enter__)
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        return self._context.__exit__(exc_type, exc_value, traceback)
+
+
+@asynccontextmanager
+async def enter_asynchronously(context: AbstractContextManager):
+    """Wrap a non-async context so that it can be entered asynchronously"""
+    async with _AsyncContextWrapper(context) as ret_value:
+        yield ret_value
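
A minimal usage sketch of the new enter_asynchronously helper with a plain threading.Lock (names are illustrative; the test added at the bottom of this diff exercises the same pattern with a multiprocessing lock):

import asyncio
import threading

from hivemind.utils.asyncio import enter_asynchronously

lock = threading.Lock()

async def hold_lock(name: str, delay: float):
    # lock.acquire() runs in a thread-pool executor via __aenter__, so other
    # coroutines keep making progress while this one waits for the lock
    async with enter_asynchronously(lock):
        print(f"{name} acquired the lock")
        await asyncio.sleep(delay)

async def main():
    await asyncio.gather(hold_lock("first", 0.2), hold_lock("second", 0.1))

asyncio.run(main())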

+ 5 - 1
hivemind/optim/performance_ema.py → hivemind/utils/performance_ema.py

@@ -37,6 +37,10 @@ class PerformanceEMA:
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
         return self.samples_per_second
 
+    def reset_timer(self):
+        """Reset the time since the last update so that the next task performance is counted from current time"""
+        self.timestamp = time.perf_counter()
+
     @contextmanager
     def pause(self):
         """While inside this context, EMA will not count the time passed towards the performance estimate"""
@@ -44,8 +48,8 @@ class PerformanceEMA:
         try:
             yield
         finally:
-            self.timestamp = time.perf_counter()
             self.paused = was_paused
+            self.reset_timer()
 
     def __repr__(self):
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"
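
A hedged sketch of how the new reset_timer is meant to be used in a mixed-precision training loop; the PerformanceEMA constructor arguments and the update() call are assumptions based on the rest of the class, which is not part of this diff:

import time

from hivemind.utils.performance_ema import PerformanceEMA

ema = PerformanceEMA(alpha=0.1)  # constructor signature assumed
batch_size = 32

for step in range(10):
    time.sleep(0.01)            # stand-in for an actual forward/backward pass
    overflowed = (step == 5)    # simulate one skipped fp16 step
    if overflowed:
        # progress was reset (e.g. fp16 gradient overflow): drop the elapsed time
        # so it does not distort the samples-per-second estimate
        ema.reset_timer()
        continue
    ema.update(batch_size)      # update() signature assumed

print(f"estimated throughput: {ema.samples_per_second:.1f} samples/s")

# pause() excludes the time spent inside the context and now ends by calling reset_timer():
with ema.pause():
    time.sleep(0.05)            # e.g. state averaging or checkpointing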

+ 5 - 2
tests/conftest.py

@@ -1,13 +1,13 @@
 import asyncio
 import gc
-import multiprocessing as mp
 from contextlib import suppress
 
 import psutil
 import pytest
 
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.mpfuture import MPFuture, SharedBytes
+from hivemind.utils.mpfuture import MPFuture
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -33,6 +33,9 @@ def event_loop():
 def cleanup_children():
     yield
 
+    with RSAPrivateKey._process_wide_key_lock:
+        RSAPrivateKey._process_wide_key = None
+
     gc.collect()  # Call .__del__() for removed objects
 
     children = psutil.Process().children(recursive=True)
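
The fixture change above is the "fix process-wide RSA keys in the validator" item: RSAPrivateKey caches one key per process behind _process_wide_key_lock, and forked tests must clear that cache between runs. A hedged sketch of the caching pattern being reset (only the _process_wide_key* attributes appear in this diff; the accessor name below is illustrative):

import threading

class ProcessWideKey:
    """Illustrative stand-in for the per-process key cache used by RSAPrivateKey"""

    _process_wide_key = None
    _process_wide_key_lock = threading.RLock()

    @classmethod
    def process_wide(cls):
        # double-checked locking: create the key once per process, then reuse it
        if cls._process_wide_key is None:
            with cls._process_wide_key_lock:
                if cls._process_wide_key is None:
                    cls._process_wide_key = cls()
        return cls._process_wide_key

# the cleanup_children fixture now does the equivalent of this for RSAPrivateKey,
# so every forked test starts from a fresh process-wide key:
with ProcessWideKey._process_wide_key_lock:
    ProcessWideKey._process_wide_key = None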

+ 2 - 2
tests/test_averaging.py

@@ -481,7 +481,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.averaging.TrainingAverager(
+    averager1 = hivemind.TrainingAverager(
         opt1,
         average_gradients=True,
         average_parameters=True,
@@ -492,7 +492,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.averaging.TrainingAverager(
+    averager2 = hivemind.TrainingAverager(
         opt2,
         average_gradients=True,
         average_parameters=True,

+ 19 - 1
tests/test_util_modules.py

@@ -11,7 +11,6 @@ import torch
 
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
@@ -28,8 +27,10 @@ from hivemind.utils.asyncio import (
     attach_event_on_finished,
     azip,
     cancel_and_wait,
+    enter_asynchronously,
 )
 from hivemind.utils.mpfuture import InvalidStateError
+from hivemind.utils.performance_ema import PerformanceEMA
 
 
 @pytest.mark.forked
@@ -538,6 +539,23 @@ async def test_cancel_and_wait():
     assert not await cancel_and_wait(task_with_error)
 
 
+@pytest.mark.asyncio
+async def test_async_context():
+    lock = mp.Lock()
+
+    async def coro1():
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.2)
+
+    async def coro2():
+        await asyncio.sleep(0.1)
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.1)
+
+    await asyncio.wait_for(asyncio.gather(coro1(), coro2()), timeout=0.5)
+    # running this without enter_asynchronously would deadlock the event loop
+
+
 def test_batch_tensor_descriptor_msgpack():
     tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
     tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))