@@ -11,6 +11,7 @@ import threading
 import uuid
 import weakref
 from concurrent.futures.thread import ThreadPoolExecutor
+from dataclasses import asdict
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator

 import grpc
@@ -18,16 +19,18 @@ from grpc._cython.cygrpc import InternalError
 import torch
 import numpy as np

-import hivemind
+from hivemind.dht import DHT, DHTID
 from hivemind.client.averaging.allreduce import AllReduceRunner, AllreduceException, GroupID, split_into_parts
+from hivemind.client.averaging.load_balancing import load_balance_peers
 from hivemind.client.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.client.averaging.group_info import GroupInfo
 from hivemind.proto import averaging_pb2, averaging_pb2_grpc, runtime_pb2
 from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
     serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
-from hivemind.utils import Endpoint, Port, MPFuture, get_logger
+from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor

 # flavour types
 StreamCallToLeader = grpc.aio.UnaryStreamCall[averaging_pb2.JoinRequest, averaging_pb2.MessageFromLeader]
@@ -85,7 +88,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     _pending_group_assembled: asyncio.Event
     serializer = MSGPackSerializer

-    def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
+    def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
                  averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
                  allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
@@ -112,12 +115,15 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
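+        # total_size and _throughput feed all-reduce load balancing; schema_hash lets matchmaking group compatible peers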
+        self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
+        self.schema_hash = compute_schema_hash(self._averaged_tensors)
+        self._throughput = throughput

         self.matchmaking_kwargs = dict(
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
-            min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout,
-            chunk_size_bytes=chunk_size_bytes, compression_type=compression_type,
-            throughput=throughput, min_vector_size=min_vector_size)
+            min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout)
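+        # compression and chunking options now configure each AllReduceRunner rather than Matchmaking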
+        self.allreduce_kwargs = dict(compression_type=compression_type, chunk_size_bytes=chunk_size_bytes,
+                                     min_vector_size=min_vector_size)
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce

@@ -170,8 +176,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         else:
             logger.info(f"The averager is running in an experimental client mode, please report any bugs.")

-        self._matchmaking = Matchmaking(self.endpoint, self._averaged_tensors, self.dht, **self.matchmaking_kwargs,
-                                        client_mode=not self.listen, return_deltas=True)
+        self._matchmaking = Matchmaking(self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs,
+                                        client_mode=not self.listen)
         if self.listen:
             asyncio.create_task(self._declare_for_download_periodically())

@@ -207,26 +213,29 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         if self._parent_pid != os.getpid() or self.is_alive():
             self.shutdown()

-    def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
-             wait=True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
+    def step(self, gather: Optional[DataForGather] = None, weight: float = 1.0, timeout: Optional[float] = None,
+             allow_retries: bool = True, wait: bool = True) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure

-        :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
-          within the specified timeout
         :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
           (this operation is known as all-gather). The gathered data will be available as the output of this function.
+        :param weight: averaging weight for this peer, int or float, must be strictly positive
+        :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
+          within the specified timeout
         :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
         :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
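+        # validate weight in the caller's process so bad values fail fast, before the request crosses the pipe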
+        assert isinstance(weight, (int, float)) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
         future, _future = MPFuture.make_pair()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
-        self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary,
+        self.pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
                                           allow_retries=allow_retries, timeout=timeout)))
         return future.result() if wait else future

-    async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]):
+    async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
+                    allow_retries: bool, timeout: Optional[float]):
         loop = asyncio.get_event_loop()
         start_time = get_dht_time()
         group_id = None
@@ -234,28 +243,28 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         while not future.done():
             try:
                 self._pending_group_assembled.clear()
-                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
-                if allreduce_group is None:
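+                # bundle this peer's weight, throughput and listen mode together with the user-supplied gather payload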
+                data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
+                if group_info is None:
                     raise AllreduceException("Averaging step failed: could not find a group.")
-
-                group_id = allreduce_group.group_id
-                self._running_groups[group_id] = allreduce_group
+                group_id = group_info.group_id
+                allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
+                self._running_groups[group_id] = allreduce_runner
                 self._pending_group_assembled.set()
-                await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
-                await loop.run_in_executor(None, self.update_tensors, allreduce_group)
+                await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
+                await loop.run_in_executor(None, self.update_tensors, allreduce_runner)

                 # averaging is finished, exit the loop
-                gathered_items = map(self.serializer.loads, allreduce_group.gathered)
-                gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
-                future.set_result(gathered_data_by_peer)
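+                # allreduce_runner.gathered already maps each endpoint to its deserialized user data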
+                future.set_result(allreduce_runner.gathered)

-            except (AllreduceException, MatchmakingException, asyncio.InvalidStateError,
-                    grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
+            except (AllreduceException, MatchmakingException, AssertionError,
+                    asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
                 time_elapsed = get_dht_time() - start_time
                 if not allow_retries or (timeout is not None and timeout < time_elapsed):
+                    logger.warning(f"Averager caught {e}")
                     future.set_result(None)
                 else:
-                    logger.debug(f"caught {e}, retrying")
+                    logger.warning(f"Averager caught {e}, retrying")

             except Exception as e:
                 future.set_exception(e)
@@ -264,6 +273,23 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             _ = self._running_groups.pop(group_id, None)
             self._pending_group_assembled.set()

+    async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
+        """ Use a group description found by Matchmaking to form an AllReduceRunner """
+        try:
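+            # each gathered blob was packed by _step as [weight, throughput, listen_mode, user_data]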
+            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
+
+            # compute optimal part sizes from peer throughputs
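+            # client-mode peers (listen=False) are assigned zero incoming throughput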
+            incoming_throughputs = [thr if listen else 0.0 for thr, listen in zip(throughputs, modes)]
+            part_sizes = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size)
+            async with self.get_tensors_async() as averaged_tensors:
+                return AllReduceRunner(group_id=group_info.group_id, tensors=averaged_tensors, endpoint=self.endpoint,
+                                       ordered_group_endpoints=group_info.endpoints, part_sizes=part_sizes,
+                                       weights=weights, gathered=user_gathered, return_deltas=True, **kwargs)
+        except Exception as e:
+            raise MatchmakingException(f"Unable to create allreduce runner ({e}), group_info: {group_info}")
+
     def update_tensors(self, allreduce_group: AllReduceRunner):
         """
         a private (extendable) method that applies changes from a finished allreduce to local tensors
@@ -288,6 +314,15 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             yield self._averaged_tensors
         self.last_updated = get_dht_time()

+    @contextlib.asynccontextmanager
+    async def get_tensors_async(self) -> Sequence[torch.Tensor]:
+        """ Like get_tensors, but uses an asynchronous contextmanager """
+        try:
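+            # acquire the shared lock in the default executor so the event loop is not blocked while waiting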
+            await asyncio.get_event_loop().run_in_executor(None, self.lock_averaged_tensors.acquire)
+            yield self._averaged_tensors
+        finally:
+            self.lock_averaged_tensors.release()
+
     async def rpc_join_group(self, request: averaging_pb2.JoinRequest, context: grpc.ServicerContext
                              ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """ accept or reject a join request from another averager; if accepted, run it through allreduce steps """
@@ -478,3 +513,11 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
                 future.set_exception(e)
                 logger.warning(e)
                 continue
+
+
+def compute_schema_hash(tensors: Sequence[torch.Tensor]) -> bytes:
+    """ A hash that describes a peer's tensor shapes, dtypes and devices, but not the actual values """
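+    # each descriptor field is converted to str so the schema can be serialized and hashed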
+    schema_dicts = [{field_name: str(field_value)
+                     for field_name, field_value in asdict(TensorDescriptor.from_tensor(tensor)).items()}
+                    for tensor in tensors]
+    return DHTID.generate(source=schema_dicts).to_bytes()