
move PerformanceEMA to utils, TrainingAverager to optim, update utils (#405)

* implement and test async wrapper for ContextManager (used in DecentralizedAverager and ProgressTracker)
* implement .reset_timer in PerformanceEMA (used when progress was reset, e.g. with fp16 gradient overflow, which should not affect samples per second)
* move PerformanceEMA to hivemind.utils (rationale: will be used in hivemind.moe in @mryab's pipelining experiments)
* move TrainingAverager to hivemind.optim (for compliance with hivemind.Optimizer and future deprecation in favour of TrainingStateAverager)
* fix process-wide RSA keys in the validator

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 3 years ago
parent commit ed4204009f

+ 2 - 1
hivemind/__init__.py

@@ -1,4 +1,4 @@
-from hivemind.averaging import DecentralizedAverager, TrainingAverager
+from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import *
 from hivemind.dht import DHT
 from hivemind.moe import (
@@ -16,6 +16,7 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
+    TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
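
Note: after this change, `TrainingAverager` is re-exported from `hivemind.optim` (and the top-level `hivemind` namespace) instead of `hivemind.averaging`. A minimal migration sketch for downstream code:

```python
from hivemind import TrainingAverager        # still works at the top level
from hivemind.optim import TrainingAverager  # new canonical location

# from hivemind.averaging import TrainingAverager  # removed by this commit
```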

+ 0 - 1
hivemind/averaging/__init__.py

@@ -1,2 +1 @@
 from hivemind.averaging.averager import DecentralizedAverager
-from hivemind.averaging.training import TrainingAverager

+ 10 - 11
hivemind/averaging/averager.py

@@ -32,7 +32,15 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
+from hivemind.utils.asyncio import (
+    achain,
+    aiter_with_timeout,
+    anext,
+    as_aiter,
+    azip,
+    enter_asynchronously,
+    switch_to_uvloop,
+)
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -453,7 +461,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )
 
-            async with self.get_tensors_async() as local_tensors:
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -505,15 +513,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
 
-    @contextlib.asynccontextmanager
-    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
-        """Like get_tensors, but uses an asynchronous contextmanager"""
-        try:
-            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
-            yield self._averaged_tensors
-        finally:
-            self.lock_averaged_tensors.release()
-
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:

+ 1 - 0
hivemind/optim/__init__.py

@@ -3,3 +3,4 @@ from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.grad_scaler import HivemindGradScaler
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
+from hivemind.optim.training_averager import TrainingAverager

+ 1 - 1
hivemind/optim/adaptive.py

@@ -2,8 +2,8 @@ from typing import Sequence
 
 import torch.optim
 
-from hivemind import TrainingAverager
 from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind.optim.training_averager import TrainingAverager
 
 
 class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):

+ 2 - 2
hivemind/optim/collaborative.py

@@ -9,14 +9,14 @@ import numpy as np
 import torch
 from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
 
-from hivemind.averaging.training import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.grad_scaler import HivemindGradScaler
-from hivemind.optim.performance_ema import PerformanceEMA
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils.performance_ema import PerformanceEMA
 
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)

+ 1 - 1
hivemind/optim/simple.py

@@ -4,9 +4,9 @@ from typing import Optional, Sequence, Tuple
 
 import torch
 
-from hivemind.averaging import TrainingAverager
 from hivemind.dht import DHT
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.training_averager import TrainingAverager
 from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)

+ 0 - 0
hivemind/averaging/training.py → hivemind/optim/training_averager.py


+ 23 - 1
hivemind/utils/asyncio.py

@@ -1,7 +1,8 @@
 import asyncio
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
-from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
+from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
+from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
 
 import uvloop
 
@@ -147,3 +148,24 @@ async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Ev
             yield item
     finally:
         event.set()
+
+
+class _AsyncContextWrapper(AbstractAsyncContextManager):
+    """Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
+
+    def __init__(self, context: AbstractContextManager):
+        self._context = context
+
+    async def __aenter__(self):
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self._context.__enter__)
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        return self._context.__exit__(exc_type, exc_value, traceback)
+
+
+@asynccontextmanager
+async def enter_asynchronously(context: AbstractContextManager):
+    """Wrap a non-async context so that it can be entered asynchronously"""
+    async with _AsyncContextWrapper(context) as ret_value:
+        yield ret_value
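
Note: `enter_asynchronously` offloads the blocking `__enter__` call to the default ThreadPoolExecutor, so other coroutines keep running while a lock is contended; `__aexit__` stays synchronous, since releasing a lock does not block. A minimal standalone sketch (not part of the diff; `heartbeat` and `use_lock` are illustrative names):

```python
import asyncio
import threading

from hivemind.utils.asyncio import enter_asynchronously

lock = threading.Lock()

async def heartbeat():
    # keeps printing while another coroutine waits for the lock,
    # demonstrating that the event loop is never blocked
    for _ in range(5):
        print("event loop is alive")
        await asyncio.sleep(0.05)

async def use_lock():
    async with enter_asynchronously(lock):  # blocking acquire runs in a worker thread
        await asyncio.sleep(0.1)

async def main():
    lock.acquire()  # simulate a lock held elsewhere
    asyncio.get_running_loop().call_later(0.25, lock.release)
    await asyncio.gather(use_lock(), heartbeat())

asyncio.run(main())
```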

+ 5 - 1
hivemind/optim/performance_ema.py → hivemind/utils/performance_ema.py

@@ -37,6 +37,10 @@ class PerformanceEMA:
         self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps)
         return self.samples_per_second
 
+    def reset_timer(self):
+        """Reset the time since the last update so that the next task performance is counted from current time"""
+        self.timestamp = time.perf_counter()
+
     @contextmanager
     def pause(self):
         """While inside this context, EMA will not count the time passed towards the performance estimate"""
@@ -44,8 +48,8 @@ class PerformanceEMA:
         try:
             yield
         finally:
-            self.timestamp = time.perf_counter()
             self.paused = was_paused
+            self.reset_timer()
 
     def __repr__(self):
         return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})"

+ 5 - 2
tests/conftest.py

@@ -1,13 +1,13 @@
 import asyncio
 import gc
-import multiprocessing as mp
 from contextlib import suppress
 
 import psutil
 import pytest
 
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.mpfuture import MPFuture, SharedBytes
+from hivemind.utils.mpfuture import MPFuture
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -33,6 +33,9 @@ def event_loop():
 def cleanup_children():
     yield
 
+    with RSAPrivateKey._process_wide_key_lock:
+        RSAPrivateKey._process_wide_key = None
+
     gc.collect()  # Call .__del__() for removed objects
 
     children = psutil.Process().children(recursive=True)
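
Note: this fixture change backs the "fix process-wide RSA keys" bullet: `RSAPrivateKey.process_wide()` caches one key per process, so without the reset a key created in one forked test could leak into later tests. A minimal sketch of the intended semantics (assuming the caching behavior implied by the attributes used in the diff):

```python
from hivemind.utils.crypto import RSAPrivateKey

key_a = RSAPrivateKey.process_wide()
key_b = RSAPrivateKey.process_wide()
assert key_a is key_b  # the key is memoized per process

# what cleanup_children() now does between tests:
with RSAPrivateKey._process_wide_key_lock:
    RSAPrivateKey._process_wide_key = None

assert RSAPrivateKey.process_wide() is not key_a  # subsequent tests get a fresh key
```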

+ 2 - 2
tests/test_averaging.py

@@ -481,7 +481,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.averaging.TrainingAverager(
+    averager1 = hivemind.TrainingAverager(
         opt1,
         average_gradients=True,
         average_parameters=True,
@@ -492,7 +492,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.averaging.TrainingAverager(
+    averager2 = hivemind.TrainingAverager(
         opt2,
         average_gradients=True,
         average_parameters=True,

+ 19 - 1
tests/test_util_modules.py

@@ -11,7 +11,6 @@ import torch
 
 import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
@@ -28,8 +27,10 @@ from hivemind.utils.asyncio import (
     attach_event_on_finished,
     azip,
     cancel_and_wait,
+    enter_asynchronously,
 )
 from hivemind.utils.mpfuture import InvalidStateError
+from hivemind.utils.performance_ema import PerformanceEMA
 
 
 @pytest.mark.forked
@@ -538,6 +539,23 @@ async def test_cancel_and_wait():
     assert not await cancel_and_wait(task_with_error)
 
 
+@pytest.mark.asyncio
+async def test_async_context():
+    lock = mp.Lock()
+
+    async def coro1():
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.2)
+
+    async def coro2():
+        await asyncio.sleep(0.1)
+        async with enter_asynchronously(lock):
+            await asyncio.sleep(0.1)
+
+    await asyncio.wait_for(asyncio.gather(coro1(), coro2()), timeout=0.5)
+    # running this without enter_asynchronously would deadlock the event loop
+
+
 def test_batch_tensor_descriptor_msgpack():
     tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
     tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))