@@ -14,22 +14,25 @@ import torch
import numpy as np

import hivemind
-from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID
+from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
-from hivemind.utils import get_logger, Endpoint, Port, MPFuture, GRPC_KEEPALIVE_OPTIONS, get_dht_time, MSGPackSerializer
+from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
+    serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
+from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
+from hivemind.utils.serializer import PickleSerializer, MSGPackSerializer
+from hivemind.utils import Endpoint, Port, MPFuture, get_logger

# flavour types
StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
-
DataForGather = Any
logger = get_logger(__name__)
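+# default chunk size for tensors streamed via rpc_download_state; 2 ** 16 bytes = 64 KiB per
+# message, comfortably below gRPC's default 4 MB message size limit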
+DEFAULT_CHUNK_SIZE_BYTES = 2 ** 16


class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
    """
-    **Warning!** Decentralized averager is in active development, some critical functionality is still underway

    Parameter averaging service. A trainer can run this service in background to periodically average his parameters
    with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small
@@ -103,6 +106,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin

        self._averaged_tensors = tuple(averaged_tensors)
        self.lock_averaged_tensors = mp.Lock()
+        self.last_updated: DHTExpiration = -float('inf')
        for tensor in self._averaged_tensors:
            assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
            tensor.share_memory_()

@@ -122,6 +126,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin

        if start:
            self.run_in_background(await_ready=True)
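+            # start a host-side thread that serves state requests sent from the averager process
+            # (see _background_thread_fetch_current_state_if_asked below)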
+ hivemind.run_in_background(self._background_thread_fetch_current_state_if_asked)

    @property
    def port(self) -> Optional[Port]:
@@ -157,6 +162,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
        self._pending_group_assembled.set()
        await server.start()
        self.ready.set()
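+        # advertise this averager in the DHT so that peers can discover it and download its state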
+ asyncio.create_task(self._declare_for_download_periodically())

        while True:
            method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
@@ -195,10 +201,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
        :returns: on success, update averaged_tensors and return group info; on failure, return None
        """
        future, _future = MPFuture.make_pair()
- self.pipe.send(('_step', [], dict(future=_future, gather=gather, allow_retries=allow_retries, timeout=timeout)))
+ gather_binary = self.serializer.dumps(gather) # serialize here to avoid loading modules in the averager process
+ self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary,
+ allow_retries=allow_retries, timeout=timeout)))
        return future.result() if wait else future

- async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
+ async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]):
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None
@@ -206,7 +214,6 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
        while not future.done():
            try:
                self._pending_group_assembled.clear()
- gather_binary = self.serializer.dumps(gather)
                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                if allreduce_group is None:
                    raise AllreduceException("Averaging step failed: could not find a group.")
@@ -245,6 +252,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
        assert len(local_tensors) == len(self._averaged_tensors)
        for tensor, update in zip(local_tensors, averaging_deltas):
            tensor.add_(update, alpha=self._averaging_alpha)
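+        # record the update time: fresher averagers are preferred as state-download sources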
+ self.last_updated = get_dht_time()

    @contextlib.contextmanager
    def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -255,6 +263,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
        """
        with self.lock_averaged_tensors:
            yield self._averaged_tensors
+ self.last_updated = get_dht_time()

    async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                             ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -279,6 +288,137 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
        async for message in group.rpc_aggregate_part(achain(aiter(request), stream), context):
            yield message

+ async def _declare_for_download_periodically(self):
+ download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
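+        # the stored value is a subkey entry {self.endpoint: last_updated}; _load_state_from_peers
+        # reads all entries under download_key and tries the most recently updated peers first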
+        while True:
+            asyncio.create_task(asyncio.wait_for(self.dht.store(
+                download_key, subkey=self.endpoint, value=self.last_updated,
+                expiration_time=get_dht_time() + self._matchmaking.averaging_expiration, return_future=True),
+                timeout=self._matchmaking.averaging_expiration))
+            await asyncio.sleep(self._matchmaking.averaging_expiration)
+
+    async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext
+                                 ) -> AsyncIterator[averaging_pb2.DownloadData]:
+        """
+        Get the up-to-date trainer state from a peer.
+        The state consists of two parts: (metadata, tensors)
+
+        - metadata is a small pickle-serialized entry meant to store scalars and hyperparameters
+        - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
+        """
+        chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
+        metadata, tensors = await self._get_current_state_from_host_process()
+
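+        # the state is streamed as a flat sequence of chunks: metadata piggybacks on the first chunk
+        # only, and a chunk with a non-empty dtype field marks the beginning of the next tensor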
+        for tensor in tensors:
+            for part in split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes):
+                if metadata is not None:
+                    yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
+                    metadata = None
+                else:
+                    yield averaging_pb2.DownloadData(tensor_part=part)
+
+ def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+        """
+        Get the current state and send it to a peer. Executed in the host process. Meant to be overridden.
+        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        """
+        with self.get_tensors() as tensors:
+ return dict(group_key=self.get_group_bits()), tensors
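+    # subclasses may override get_current_state to ship extra training state; a sketch, assuming a
+    # hypothetical subclass with a `local_step` attribute:
+    #
+    #     def get_current_state(self):
+    #         with self.get_tensors() as tensors:
+    #             return dict(step=self.local_step, group_key=self.get_group_bits()), tensors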
+
+ async def _get_current_state_from_host_process(self):
+        """ Executed in the averager process inside rpc_download_state """
+        future, _future = MPFuture.make_pair()
+        self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
+        return await future
+
+ def _background_thread_fetch_current_state_if_asked(self):
+        """ Executed in the host process as a background thread. """
+        while True:
+            trigger, future = self.pipe.recv()
+            assert trigger == '_TRIGGER_GET_CURRENT_STATE'
+            try:
+                state_metadata, state_tensors = self.get_current_state()
+                # note: serialize here to avoid initializing cuda in the guest process
+                state_metadata = PickleSerializer.dumps(state_metadata)
+                state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
+                                      for tensor in state_tensors)
+                future.set_result((state_metadata, state_tensors))
+            except BaseException as e:
+                future.set_exception(e)
+                logger.warning(e)
+                continue
+
+ def load_state_from_peers(self, wait=True) -> Optional[Any]:
+        """ Try to download the latest optimizer state from one of the existing peers """
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
+ return future.result() if wait else future
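+    # intended usage (a sketch): call load_state_from_peers() when a trainer joins mid-run, e.g.
+    #
+    #     loaded_state = averager.load_state_from_peers()
+    #     if loaded_state is not None:
+    #         metadata, tensors = loaded_state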
+
+ async def _load_state_from_peers(self, future: MPFuture):
+        key_manager = self._matchmaking.group_key_manager
+        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
+        peer_priority = {peer: float(info.value) for peer, info in peer_priority.items()
+                         if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))}
+
+        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
+            logger.info(f"Averager could not load state from peers: peer dict is absent or corrupted: {peer_priority}.")
+ future.set_result(None)
+ return
+
+        metadata = None
+        for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
+            if peer != self.endpoint:
+                logger.info(f"Downloading parameters from peer {peer}")
+                stream = None
+                try:
+                    leader_stub = ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+                    stream = leader_stub.rpc_download_state(averaging_pb2.DownloadRequest())
+                    current_tensor_parts, tensors = [], []
+                    async for message in stream:
+                        if message.metadata:
+                            metadata = PickleSerializer.loads(message.metadata)
+                        if message.tensor_part.dtype and current_tensor_parts:
+                            # a non-empty tensor_part.dtype marks the start of a new tensor, so wrap up the current one
+                            tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+                            current_tensor_parts = []
+                        current_tensor_parts.append(message.tensor_part)
+                    if current_tensor_parts:
+                        tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
+                    future.set_result((metadata, tensors))
+                    self.last_updated = get_dht_time()
+                    return
+                except grpc.aio.AioRpcError as e:
+                    logger.info(f"Failed to download state from {peer} - {e}")
+                finally:
+                    if stream is not None:
+                        await stream.code()
+
+        else:  # for-else: runs only if the loop above finished without returning a downloaded state
+            logger.warning("Averager could not load state from peers: found no active peers.")
+            future.set_result(None)
+
+    def get_group_bits(self, wait: bool = True):
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_group_bits', [], dict(future=_future)))
+        return future.result() if wait else future
+
+    async def _get_group_bits(self, future: MPFuture):
+        future.set_result(self._matchmaking.group_key_manager.group_bits)
+
+ def set_group_bits(self, group_bits: str, wait: bool = True):
+        future, _future = MPFuture.make_pair()
+        assert all(bit in '01' for bit in group_bits)
+        self.pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
+        return future.result() if wait else future
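+    # group bits select this averager's matchmaking bucket; e.g. set_group_bits('0110') points the
+    # averager at the corresponding group key (the string must consist of 0s and 1s)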
+
+ async def _set_group_bits(self, group_bits: str, future: MPFuture):
+        try:
+            self._matchmaking.group_key_manager.group_bits = group_bits
+            return future.set_result(None)
+        except Exception as e:
+            if not future.done():
+                future.set_exception(e)
+

def is_power_of_two(n):
    """ Check whether n is a power of 2 """