
Merge remote-tracking branch 'origin/master' into tr

Michael Diskin 4 years ago
parent
commit
4f36589aa5

+ 10 - 6
hivemind/averaging/allreduce.py

@@ -8,7 +8,7 @@ from hivemind.averaging.partition import AllreduceException, TensorPartContainer
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
-from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext
+from hivemind.utils.asyncio import achain, aenumerate, afirst, amap_in_executor, anext, as_aiter
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 # flavour types
@@ -153,13 +153,17 @@ class AllReduceRunner(ServicerBase):
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
-            loop = asyncio.get_event_loop()
             code = None
             stream = self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, msg in aenumerate(stream):
+            async for part_index, (averaged_part_delta, msg) in aenumerate(
+                amap_in_executor(
+                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
+                    stream,
+                    max_prefetch=self.tensor_part_container.prefetch,
+                )
+            ):
                 if code is None:
                     code = msg.code
-                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
 
             if code != averaging_pb2.AVERAGED_PART:
@@ -193,7 +197,7 @@ class AllReduceRunner(ServicerBase):
         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
                 sender_index = self.sender_peer_ids.index(context.remote_id)
-                async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
+                async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
                     yield msg
 
             except Exception as e:
@@ -232,7 +236,7 @@ class AllReduceRunner(ServicerBase):
     async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
         error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
         # Coroutines are lazy, so we take the first item to start the coroutine's execution
-        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(aiter(error)))
+        await afirst(self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""

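The allreduce change above replaces a per-message `loop.run_in_executor(None, deserialize_torch_tensor, ...)` call with `amap_in_executor`, which keeps up to `max_prefetch` deserialized parts in flight while the coroutine consumes earlier ones, overlapping network receipt with CPU-bound deserialization. Below is a minimal standalone sketch of that prefetching pattern; it is an illustrative reimplementation, not the actual `hivemind.utils.asyncio.amap_in_executor`.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

async def amap_prefetch(func, stream, max_prefetch=5):
    """Apply func to stream items in a worker thread, keeping up to max_prefetch results in flight."""
    loop = asyncio.get_running_loop()
    queue = asyncio.Queue(max_prefetch)  # bounds the number of pending results

    async def producer():
        async for item in stream:
            await queue.put(loop.run_in_executor(pool, func, item))
        await queue.put(None)  # sentinel: the stream is exhausted

    with ThreadPoolExecutor(max_workers=1) as pool:
        task = asyncio.create_task(producer())
        try:
            while (future := await queue.get()) is not None:
                yield await future  # results arrive in submission order
        finally:
            task.cancel()

async def _demo():
    async def numbers():
        for i in range(10):
            yield i

    async for square in amap_prefetch(lambda x: x * x, numbers(), max_prefetch=3):
        print(square)

asyncio.run(_demo())
```

The queue bound is what `TensorPartContainer.prefetch` controls in the diff below: with the new default of 5, several small parts can be deserialized ahead of the consumer.
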
+ 61 - 46
hivemind/averaging/averager.py

@@ -25,7 +25,7 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2, runtime_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
@@ -197,6 +197,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def peer_id(self) -> PeerID:
         return self.dht.peer_id
 
+    @property
+    def request_timeout(self):
+        return self._matchmaking.request_timeout
+
     def run(self):
         """
         Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
@@ -211,48 +215,56 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
-            async def _run():
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                if not self.client_mode:
+                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
+                else:
+                    logger.debug(f"The averager is running in client mode.")
+
+                self._matchmaking = Matchmaking(
+                    self._p2p,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
+                )
+                if not self.client_mode:
+                    asyncio.create_task(self._declare_for_download_periodically())
+
+                self._pending_group_assembled = asyncio.Event()
+                self._pending_group_assembled.set()
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
                 try:
-                    self._p2p = await self.dht.replicate_p2p()
-                    if not self.client_mode:
-                        await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
-                    else:
-                        logger.debug(f"The averager is running in client mode.")
-
-                    self._matchmaking = Matchmaking(
-                        self._p2p,
-                        self.schema_hash,
-                        self.dht,
-                        client_mode=self.client_mode,
-                        **self.matchmaking_kwargs,
-                    )
-                    if not self.client_mode:
-                        asyncio.create_task(self._declare_for_download_periodically())
-
-                    self._pending_group_assembled = asyncio.Event()
-                    self._pending_group_assembled.set()
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._matchmaking.request_timeout)
-                        continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
-
-            loop.run_until_complete(_run())
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self.request_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
+                try:
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self.request_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
 
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
@@ -484,7 +496,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return
 
-        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
+        async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message
 
     async def _declare_for_download_periodically(self):
@@ -542,7 +554,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
         return await future
 
-    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
+    def load_state_from_peers(
+        self, wait: bool = True, timeout: Optional[float] = None
+    ) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state from one of the existing peers.
         :returns: on success, return a 2-tuple with (metadata, tensors), where
@@ -554,7 +568,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         future = MPFuture()
         self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
-        return future.result() if wait else future
+        return future.result(timeout=timeout) if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture):
         try:
@@ -579,7 +593,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
-                        async for message in stream:
+
+                        async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
@@ -603,7 +618,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         finally:
             if not future.done():
-                logger.warning("Averager could not load state from peers: all requests have failed.")
+                logger.warning("Averager could not load state from peers: none of the requests succeeded.")
                 future.set_result(None)
 
     def get_group_bits(self, wait: bool = True):
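Both `DecentralizedAverager.run` above and `DHT.run` below drop the single-worker `ThreadPoolExecutor` that used to block on `pipe.recv()` and instead let the event loop wake itself: `loop.add_reader` registers the pipe's file descriptor and releases a semaphore whenever it becomes readable, and the loop then checks `poll()` before a now non-blocking `recv()`. A condensed sketch of this wake-up pattern follows, with simplified, hypothetical process plumbing (it relies on Unix selector-based loops, which includes uvloop):

```python
import asyncio
import multiprocessing as mp

def serve(inner_pipe, request_timeout=3.0):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    wakeup = asyncio.Semaphore(value=0)
    # release the semaphore whenever the pipe's fd becomes readable
    loop.add_reader(inner_pipe.fileno(), wakeup.release)

    async def _run():
        while True:
            try:
                # also wake up periodically even if no message arrives
                await asyncio.wait_for(wakeup.acquire(), timeout=request_timeout)
            except asyncio.TimeoutError:
                pass
            if not inner_pipe.poll():
                continue  # timeout or spurious wake-up: nothing to read
            method, args = inner_pipe.recv()  # does not block: poll() was True
            print(f"dispatching {method}{args}")
            if method == "_shutdown":
                break

    loop.run_until_complete(_run())

if __name__ == "__main__":
    outer_pipe, inner_pipe = mp.Pipe()
    proc = mp.Process(target=serve, args=(inner_pipe,))
    proc.start()
    outer_pipe.send(("_work", (1, 2)))
    outer_pipe.send(("_shutdown", ()))
    proc.join()
```

Unlike the old executor thread, this leaves no thread blocked in `recv()` at shutdown, and the `wait_for` timeout preserves the previous behavior of periodically re-checking the pipe.
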

+ 2 - 2
hivemind/averaging/partition.py

@@ -13,7 +13,7 @@ from hivemind.utils.asyncio import amap_in_executor
 from hivemind.utils.compression import get_nbytes_per_value, serialize_torch_tensor
 
 T = TypeVar("T")
-DEFAULT_PART_SIZE_BYTES = 2 ** 19
+DEFAULT_PART_SIZE_BYTES = 2 ** 16
 
 
 class TensorPartContainer:
@@ -33,7 +33,7 @@ class TensorPartContainer:
         peer_fractions: Sequence[float],
         compression_type: Union["CompressionType", Sequence["CompressionType"]] = CompressionType.NONE,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
-        prefetch: int = 1,
+        prefetch: int = 5,
     ):
         if not isinstance(compression_type, Sequence):
             compression_type = [compression_type] * len(tensors)
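Shrinking the default part size from 2 ** 19 to 2 ** 16 bytes (512 KiB to 64 KiB) while raising the default prefetch from 1 to 5 trades fewer large messages for a deeper pipeline of small ones. A quick back-of-the-envelope for scale, assuming uncompressed float32 tensors and a made-up model size:

```python
# how many parts a tensor is split into under the old vs. new default
tensor_bytes = 100 * 10**6 * 4  # e.g. a hypothetical 100M-parameter float32 model

for part_size in (2**19, 2**16):  # old vs. new DEFAULT_PART_SIZE_BYTES
    num_parts = -(-tensor_bytes // part_size)  # ceiling division
    print(f"{part_size >> 10:>4} KiB parts -> {num_parts} parts")
# 512 KiB parts -> 763 parts
#  64 KiB parts -> 6104 parts
```
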

+ 43 - 37
hivemind/dht/__init__.py

@@ -102,45 +102,51 @@ class DHT(mp.Process):
 
     def run(self) -> None:
         """Serve DHT forever. This function will not return until DHT node is shut down"""
-        loop = switch_to_uvloop()
-
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
-            async def _run():
+        loop = switch_to_uvloop()
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                if self._daemon_listen_maddr is not None:
+                    replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
+                else:
+                    replicated_p2p = None
+
+                self._node = await DHTNode.create(
+                    initial_peers=self.initial_peers,
+                    num_workers=self.num_workers,
+                    record_validator=self._record_validator,
+                    p2p=replicated_p2p,
+                    **self.kwargs,
+                )
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
+                try:
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self._node.protocol.wait_timeout)
+                except asyncio.TimeoutError:
+                    pass
+                if not self._inner_pipe.poll():
+                    continue
                 try:
-                    if self._daemon_listen_maddr is not None:
-                        replicated_p2p = await P2P.replicate(self._daemon_listen_maddr)
-                    else:
-                        replicated_p2p = None
-
-                    self._node = await DHTNode.create(
-                        initial_peers=self.initial_peers,
-                        num_workers=self.num_workers,
-                        record_validator=self._record_validator,
-                        p2p=replicated_p2p,
-                        **self.kwargs,
-                    )
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._node.protocol.wait_timeout)
-                        continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
-
-            coro = _run()
-            loop.run_until_complete(coro)
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self._node.protocol.wait_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())
 
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """

+ 13 - 2
hivemind/optim/collaborative.py

@@ -85,6 +85,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
     :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
     :param scheduler: if specified, use this scheduler to update optimizer learning rate
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
@@ -114,6 +115,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         performance_ema_alpha: float = 0.1,
         metadata_expiration: float = 60.0,
         averaging_timeout: Optional[float] = None,
+        load_state_timeout: float = 600.0,
         step_tolerance: int = 1,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
@@ -137,7 +139,9 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             default_refresh_period,
         )
         self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
-        self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
+        self.averaging_timeout = averaging_timeout
+        self.load_state_timeout = load_state_timeout
+        self.metadata_expiration = metadata_expiration
         self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
         self.client_mode, self.step_tolerance = client_mode, step_tolerance
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
@@ -185,7 +189,14 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
         with self.lock_collaboration_state:
-            self.averager.load_state_from_peers(**kwargs)
+            while True:
+                try:
+                    self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
+                    continue
+
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.reset_accumulated_grads_()
             self.update_scheduler()
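`load_state_from_peers` now forwards its `timeout` to the future's `result()`, so a hung download surfaces as an exception that the retry loop above catches, instead of blocking the optimizer indefinitely. A hedged caller-side sketch, where `averager` stands for any running `DecentralizedAverager` and the broad `BaseException` catch mirrors the diff, intentionally retrying on timeouts as well:

```python
import logging

logger = logging.getLogger(__name__)
LOAD_STATE_TIMEOUT = 600.0  # seconds, matching the new default above

def load_state_with_retries(averager):
    """Keep retrying until a peer serves the state within the timeout."""
    while True:
        try:
            # raises (e.g. a future timeout) if no peer responds in time
            return averager.load_state_from_peers(timeout=LOAD_STATE_TIMEOUT)
        except BaseException as e:
            logger.exception(f"Failed to load state from peers: {e}, retrying ...")
```
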

+ 2 - 2
hivemind/p2p/p2p_daemon.py

@@ -15,7 +15,7 @@ import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
-from hivemind.utils.asyncio import aiter, asingle
+from hivemind.utils.asyncio import as_aiter, asingle
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -480,7 +480,7 @@ class P2P:
         input: Union[TInputProtobuf, TInputStream],
         output_protobuf_type: Type[Message],
     ) -> TOutputStream:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
+        requests = input if isinstance(input, AsyncIterableABC) else as_aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
 
     def _start_listening(self) -> None:

+ 12 - 1
hivemind/utils/asyncio.py

@@ -28,7 +28,7 @@ async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
     return await aiter.__anext__()
 
 
-async def aiter(*args: T) -> AsyncIterator[T]:
+async def as_aiter(*args: T) -> AsyncIterator[T]:
     """create an asynchronous iterator from a sequence of values"""
     for arg in args:
         yield arg
@@ -127,3 +127,14 @@ async def amap_in_executor(
     finally:
         if not task.done():
             task.cancel()
+
+
+async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
+    """Iterate over an async iterable; raise TimeoutError if the next item does not arrive within the timeout"""
+    # based on https://stackoverflow.com/a/50245879
+    iterator = iterable.__aiter__()
+    while True:
+        try:
+            yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
+        except StopAsyncIteration:
+            break
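The `aiter` → `as_aiter` rename applied throughout this commit presumably frees the name `aiter`, which became a built-in in Python 3.10. The new `aiter_with_timeout` wraps any async iterable so that a stalled stream raises instead of hanging forever; a minimal usage sketch built only from the helpers in this diff:

```python
import asyncio

from hivemind.utils.asyncio import aiter_with_timeout, as_aiter

async def main():
    # items are ready immediately, so the timeout never fires
    async for item in aiter_with_timeout(as_aiter(1, 2, 3), timeout=1.0):
        print(item)

    async def slow_stream():
        yield "first"
        await asyncio.sleep(5.0)  # longer than the timeout below
        yield "second"

    try:
        async for item in aiter_with_timeout(slow_stream(), timeout=0.5):
            print(item)  # prints "first", then the next item times out
    except asyncio.TimeoutError:
        print("stream stalled")

asyncio.run(main())
```
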

+ 2 - 2
setup.py

@@ -14,8 +14,8 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.4"
-P2PD_CHECKSUM = "194dca06116fdd36bc4b681d18f3b9cb"
+P2PD_VERSION = "v0.3.5"
+P2PD_CHECKSUM = "affea8ec63dbe2423ef7453718b5798d"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
 
 here = os.path.abspath(os.path.dirname(__file__))

+ 33 - 13
tests/test_util_modules.py

@@ -17,9 +17,10 @@ from hivemind.utils.asyncio import (
     achain,
     aenumerate,
     afirst,
-    aiter,
+    aiter_with_timeout,
     amap_in_executor,
     anext,
+    as_aiter,
     asingle,
     azip,
     cancel_and_wait,
@@ -478,20 +479,23 @@ def test_generic_data_classes():
 
 @pytest.mark.asyncio
 async def test_asyncio_utils():
-    res = [i async for i, item in aenumerate(aiter("a", "b", "c"))]
+    res = [i async for i, item in aenumerate(as_aiter("a", "b", "c"))]
     assert res == list(range(len(res)))
 
     num_steps = 0
-    async for elem in amap_in_executor(lambda x: x ** 2, aiter(*range(100)), max_prefetch=5):
+    async for elem in amap_in_executor(lambda x: x ** 2, as_aiter(*range(100)), max_prefetch=5):
         assert elem == num_steps ** 2
         num_steps += 1
     assert num_steps == 100
 
-    ours = [elem async for elem in amap_in_executor(max, aiter(*range(7)), aiter(*range(-50, 50, 10)), max_prefetch=1)]
+    ours = [
+        elem
+        async for elem in amap_in_executor(max, as_aiter(*range(7)), as_aiter(*range(-50, 50, 10)), max_prefetch=1)
+    ]
     ref = list(map(max, range(7), range(-50, 50, 10)))
     assert ours == ref
 
-    ours = [row async for row in azip(aiter("a", "b", "c"), aiter(1, 2, 3))]
+    ours = [row async for row in azip(as_aiter("a", "b", "c"), as_aiter(1, 2, 3))]
     ref = list(zip(["a", "b", "c"], [1, 2, 3]))
     assert ours == ref
 
@@ -507,18 +511,34 @@ async def test_asyncio_utils():
     with pytest.raises(StopAsyncIteration):
         await anext(iterator)
 
-    assert [item async for item in achain(_aiterate(), aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
+    assert [item async for item in achain(_aiterate(), as_aiter(*range(5)))] == ["foo", "bar", "baz"] + list(range(5))
 
-    assert await asingle(aiter(1)) == 1
+    assert await asingle(as_aiter(1)) == 1
     with pytest.raises(ValueError):
-        await asingle(aiter())
+        await asingle(as_aiter())
     with pytest.raises(ValueError):
-        await asingle(aiter(1, 2, 3))
+        await asingle(as_aiter(1, 2, 3))
 
-    assert await afirst(aiter(1)) == 1
-    assert await afirst(aiter()) is None
-    assert await afirst(aiter(), -1) == -1
-    assert await afirst(aiter(1, 2, 3)) == 1
+    assert await afirst(as_aiter(1)) == 1
+    assert await afirst(as_aiter()) is None
+    assert await afirst(as_aiter(), -1) == -1
+    assert await afirst(as_aiter(1, 2, 3)) == 1
+
+    async def iterate_with_delays(delays):
+        for i, delay in enumerate(delays):
+            await asyncio.sleep(delay)
+            yield i
+
+    async for _ in aiter_with_timeout(iterate_with_delays([0.1] * 5), timeout=0.2):
+        pass
+
+    sleepy_aiter = iterate_with_delays([0.1, 0.1, 0.3, 0.1, 0.1])
+    num_steps = 0
+    with pytest.raises(asyncio.TimeoutError):
+        async for _ in aiter_with_timeout(sleepy_aiter, timeout=0.2):
+            num_steps += 1
+
+    assert num_steps == 2
 
 
 @pytest.mark.asyncio