@@ -9,7 +9,6 @@ import multiprocessing as mp
 import os
 import threading
 import weakref
-from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
 from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union

@@ -21,12 +20,18 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    NoCompression,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
-from hivemind.utils.asyncio import achain, aiter, anext, switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
@@ -51,7 +56,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
     :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param compression: optionally compress tensors with this compression algorithm before running all-reduce
+    :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
     :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
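For context, the new arguments accept any `CompressionBase` implementation. A rough usage sketch follows; `Float16Compression` and the exact set of required constructor arguments (e.g. `target_group_size`) are assumptions that may differ across hivemind versions, not part of this diff:

```python
# Illustrative only: constructing an averager with the new compression API.
import torch
from hivemind import DHT
from hivemind.averaging import DecentralizedAverager
from hivemind.compression import Float16Compression, NoCompression  # assumed exports

dht = DHT(start=True)
averager = DecentralizedAverager(
    averaged_tensors=[torch.zeros(100)],
    dht=dht,
    prefix="my_experiment",
    target_group_size=2,
    compression=Float16Compression(),  # applied to tensor parts during all-reduce
    state_compression=NoCompression(),  # applied in rpc_download_state (the default)
    start=True,
)
```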
@@ -102,7 +109,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
-        compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+        compression: CompressionBase = NoCompression(),
+        state_compression: CompressionBase = NoCompression(),
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -158,7 +167,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             request_timeout=request_timeout,
         )
         self.allreduce_kwargs = dict(
-            compression_type=compression_type, part_size_bytes=part_size_bytes, min_vector_size=min_vector_size
+            compression=compression,
+            part_size_bytes=part_size_bytes,
+            min_vector_size=min_vector_size,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -169,6 +180,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.state_compression = state_compression
+        self.tensor_infos = tensor_infos

         self._ready = MPFuture()
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
@@ -197,6 +210,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def peer_id(self) -> PeerID:
         return self.dht.peer_id

+    @property
+    def request_timeout(self):
+        return self._matchmaking.request_timeout
+
     def run(self):
         """
         Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
@@ -211,48 +228,56 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """Serve DecentralizedAverager forever. This function will not return until the averager is shut down"""
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
-        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
-
-            async def _run():
+        pipe_semaphore = asyncio.Semaphore(value=0)
+        loop.add_reader(self._inner_pipe.fileno(), pipe_semaphore.release)
+
+        async def _run():
+            try:
+                self._p2p = await self.dht.replicate_p2p()
+                if not self.client_mode:
+                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
+                else:
+                    logger.debug(f"The averager is running in client mode.")
+
+                self._matchmaking = Matchmaking(
+                    self._p2p,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
+                )
+                if not self.client_mode:
+                    asyncio.create_task(self._declare_for_download_periodically())
+
+                self._pending_group_assembled = asyncio.Event()
+                self._pending_group_assembled.set()
+            except Exception as e:
+                # Loglevel is DEBUG since normally the exception is propagated to the caller
+                logger.debug(e, exc_info=True)
+                self._ready.set_exception(e)
+                return
+            self._ready.set_result(None)
+
+            while True:
                 try:
-                    self._p2p = await self.dht.replicate_p2p()
-                    if not self.client_mode:
-                        await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
-                    else:
-                        logger.debug(f"The averager is running in client mode.")
-
-                    self._matchmaking = Matchmaking(
-                        self._p2p,
-                        self.schema_hash,
-                        self.dht,
-                        client_mode=self.client_mode,
-                        **self.matchmaking_kwargs,
-                    )
-                    if not self.client_mode:
-                        asyncio.create_task(self._declare_for_download_periodically())
-
-                    self._pending_group_assembled = asyncio.Event()
-                    self._pending_group_assembled.set()
-                except Exception as e:
-                    # Loglevel is DEBUG since normally the exception is propagated to the caller
-                    logger.debug(e, exc_info=True)
-                    self._ready.set_exception(e)
-                    return
-                self._ready.set_result(None)
-
-                while True:
-                    try:
-                        method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
-                    except (OSError, ConnectionError) as e:
-                        logger.exception(e)
-                        await asyncio.sleep(self._matchmaking.request_timeout)
-                        continue
-                    task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
-                    if method == "_shutdown":
-                        await task
-                        break
-
-            loop.run_until_complete(_run())
+                    await asyncio.wait_for(pipe_semaphore.acquire(), timeout=self.request_timeout)
+                except asyncio.TimeoutError:
+                    pass
+
+                if not self._inner_pipe.poll():
+                    continue
+                try:
+                    method, args, kwargs = self._inner_pipe.recv()
+                except (OSError, ConnectionError, RuntimeError) as e:
+                    logger.exception(e)
+                    await asyncio.sleep(self.request_timeout)
+                    continue
+                task = asyncio.create_task(getattr(self, method)(*args, **kwargs))
+                if method == "_shutdown":
+                    await task
+                    break
+
+        loop.run_until_complete(_run())

     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
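The rewritten loop replaces the dedicated reader thread with a pipe watcher built from `loop.add_reader` and a semaphore: the event loop releases the semaphore whenever the inner pipe becomes readable, and the coroutine also wakes up every `request_timeout` seconds to re-poll. A self-contained sketch of that wake-up pattern outside hivemind (Unix only, since `add_reader` on a pipe descriptor needs a selector-based event loop; names here are local to the sketch):

```python
# Standalone sketch of the semaphore-based pipe wake-up pattern used above.
import asyncio
import multiprocessing as mp

async def consume(inner_pipe, request_timeout=3.0):
    loop = asyncio.get_running_loop()
    pipe_semaphore = asyncio.Semaphore(value=0)
    # release the semaphore whenever the pipe has data to read
    loop.add_reader(inner_pipe.fileno(), pipe_semaphore.release)
    while True:
        try:
            await asyncio.wait_for(pipe_semaphore.acquire(), timeout=request_timeout)
        except asyncio.TimeoutError:
            pass  # periodic wakeup: re-check the pipe even without a reader event
        if not inner_pipe.poll():
            continue
        message = inner_pipe.recv()  # will not block: poll() just returned True
        if message == "stop":
            break
        print("received:", message)

if __name__ == "__main__":
    outer, inner = mp.Pipe()
    outer.send("hello")
    outer.send("stop")
    asyncio.run(consume(inner))
```

Unlike the old `run_in_executor(pipe_awaiter, self._inner_pipe.recv)` design, nothing ever blocks a thread on `recv()`, so shutdown cannot strand a reader thread waiting on a pipe that will never deliver.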
@@ -351,7 +376,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):

                     future.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout
+                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            timeout=self._allreduce_timeout,
                         )
                     )
                     # averaging is finished, loop will now exit
@@ -484,7 +510,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return

-        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
+        async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message

     async def _declare_for_download_periodically(self):
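The helper is presumably renamed from `aiter` to `as_aiter` to avoid clashing with the `aiter` builtin added in Python 3.10; the handler still uses it to push the first message (already consumed via `anext`) back in front of the remaining stream. A toy model of that chaining, with local stand-ins for the hivemind helpers, shown for intuition only:

```python
# Toy stand-ins for hivemind.utils.asyncio.as_aiter / achain.
import asyncio

async def as_aiter(*items):
    # wrap already-received items into an async iterator
    for item in items:
        yield item

async def achain(*aiters):
    # yield from several async iterators, one after another
    for aiterator in aiters:
        async for item in aiterator:
            yield item

async def main():
    async def rest_of_stream():
        for x in (2, 3):
            yield x

    first_message = 1  # stands in for `await anext(stream)`
    combined = [x async for x in achain(as_aiter(first_message), rest_of_stream())]
    assert combined == [1, 2, 3]

asyncio.run(main())
```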
@@ -517,24 +543,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         if not self.allow_state_sharing:
             return  # deny request and direct peer to the next prospective averager
-        metadata, tensors = await self._get_current_state_from_host_process()
+        metadata, tensors, infos = await self._get_current_state_from_host_process()
+        if infos is None:
+            infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
+        assert len(tensors) == len(infos)

-        for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor)):
+        for tensor, info in zip(tensors, infos):
+            for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
                 else:
                     yield averaging_pb2.DownloadData(tensor_part=part)

-    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
         """
         Get current state and send it to a peer. executed in the host process. Meant to be overriden.
         :returns: a tuple of (small metadata, sequence of torch tensors)
         :note: metadata must be seriablizable with self.serializer (default = MSGPackSerializer)
         """
         with self.get_tensors() as tensors:
-            return dict(group_key=self.get_group_bits()), tensors
+            return dict(group_key=self.get_group_bits()), tensors, self.tensor_infos

     async def _get_current_state_from_host_process(self):
         """Executed in the averager process inside rpc_download_state"""
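Subclasses that override `get_current_state` should now return a 3-tuple; when the third element is `None`, `rpc_download_state` falls back to `CompressionInfo.from_tensor(tensor, key=i)` for each tensor, as shown above. A minimal sketch of a conforming override (the `step` field is illustrative, not part of this diff):

```python
# Minimal sketch of the new get_current_state contract.
class MyAverager(DecentralizedAverager):
    def get_current_state(self):
        with self.get_tensors() as tensors:
            metadata = dict(group_key=self.get_group_bits(), step=0)  # step: illustrative
            # tensor_infos may be None: the serving side then derives
            # a CompressionInfo for each tensor automatically
            return metadata, tensors, self.tensor_infos
```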
@@ -542,7 +571,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe.send(("_TRIGGER_GET_CURRENT_STATE", future))
         return await future

-    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
+    def load_state_from_peers(
+        self, wait: bool = True, timeout: Optional[float] = None
+    ) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state one of the existing peer.
         :returns: on success, return a 2-tuple with (metadata, tensors), where
@@ -554,7 +585,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         future = MPFuture()
         self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
-        return future.result() if wait else future
+        return future.result(timeout=timeout) if wait else future

     async def _load_state_from_peers(self, future: MPFuture):
         try:
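With the new `timeout` argument, callers can bound how long they block on a peer's state, while `wait=False` still returns the `MPFuture` directly. A short usage sketch (the timeout value is illustrative):

```python
# Bound how long we block waiting for a peer's state. If every candidate peer
# fails, the result is None; if nothing arrives within `timeout`, the
# underlying future.result(timeout=...) raises a TimeoutError instead.
loaded_state = averager.load_state_from_peers(timeout=60.0)
if loaded_state is not None:
    metadata, tensors = loaded_state

# Non-blocking variant: get the MPFuture and check it later.
state_future = averager.load_state_from_peers(wait=False)
```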
@@ -579,7 +610,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
                         stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
-                        async for message in stream:
+
+                        async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
@@ -603,7 +635,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):

         finally:
             if not future.done():
-                logger.warning("Averager could not load state from peers: all requests have failed.")
                 future.set_result(None)

     def get_group_bits(self, wait: bool = True):
@@ -666,7 +697,11 @@ def _background_thread_fetch_current_state(
             get_current_state = get_current_state_ref()
             if get_current_state is None:
                 break
-            state_metadata, state_tensors = get_current_state()
+            state = get_current_state()
+            assert 0 < len(state) <= 3
+            if len(state) != 3:
+                state = tuple(state + (None,) * (3 - len(state)))
+            state_metadata, state_tensors, tensor_infos = state
             del get_current_state

             state_metadata = serializer.dumps(state_metadata)
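The padding above keeps legacy overrides working: a 2-tuple `(metadata, tensors)` returned by an older `get_current_state` is extended with `tensor_infos = None`, which the serving side later replaces with per-tensor defaults. The compatibility shim in isolation (sample values are illustrative):

```python
# Pad a legacy 2-tuple result to the new 3-tuple contract.
state = ({"step": 1}, [])  # old-style (metadata, tensors) result
assert 0 < len(state) <= 3
if len(state) != 3:
    state = tuple(state + (None,) * (3 - len(state)))
state_metadata, state_tensors, tensor_infos = state
assert tensor_infos is None  # downstream code treats None as "use defaults"
```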
@@ -674,7 +709,7 @@ def _background_thread_fetch_current_state(
                 tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
             )
             # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
-            future.set_result((state_metadata, state_tensors))
+            future.set_result((state_metadata, state_tensors, tensor_infos))
         except BaseException as e:
             future.set_exception(e)
             logger.warning(e)