@@ -32,7 +32,15 @@ from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
+from hivemind.utils.asyncio import (
+    achain,
+    aiter_with_timeout,
+    anext,
+    as_aiter,
+    azip,
+    enter_asynchronously,
+    switch_to_uvloop,
+)
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -453,7 +461,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )
 
-            async with self.get_tensors_async() as local_tensors:
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
                     servicer_type=type(self),
@@ -505,15 +513,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
 
-    @contextlib.asynccontextmanager
-    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
-        """Like get_tensors, but uses an asynchronous contextmanager"""
-        try:
-            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
-            yield self._averaged_tensors
-        finally:
-            self.lock_averaged_tensors.release()
-
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]: