@@ -31,7 +31,7 @@ from hivemind.utils import Endpoint, Port, MPFuture, get_logger, TensorDescriptor
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, split_for_streaming, combine_from_streaming
-from hivemind.utils.networking import choose_ip_address, strip_port
+from hivemind.utils.networking import choose_ip_address, strip_port, Hostname
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration

@@ -64,11 +64,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
       :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
     :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
-    :param throughput: if specified, this value represents the network bandwidth available to averager.
+    :param bandwidth: if specified, this value represents the network bandwidth available to averager.
       By default, the averager is assumed to have the average bandwidth of its group.
-      If throughput == 0, averager will rely on its groupmates to do all the averaging.
-    :param listen: if True (default), this averager will accept incoming requests from other peers and perform allreduce;
-      if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
+      If bandwidth == 0, averager will rely on its groupmates to do all the averaging.
+    :param client_mode: if False (default), this averager will accept incoming requests from other peers;
+      if True, the averager will only join existing groups where at least one peer has client_mode=False
     :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
     :param announced_host: visible IP address the averager will announce for external connections from other peers.
      If None, the address will be chosen from p2p.get_visible_maddrs() (global IPv4 addresses are preferred)
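For orientation, a minimal construction sketch under the renamed API. This is a hypothetical usage, not part of the patch; the tensor shape, prefix, and bandwidth value are placeholders:

```python
# Hypothetical usage of the renamed arguments; not part of this patch.
import torch
import hivemind

dht = hivemind.DHT(start=True)  # a running DHT instance used for matchmaking
averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(16)],  # placeholder tensors to average with peers
    dht=dht,
    prefix="demo_run",      # group key prefix; must not contain '.'
    target_group_size=4,    # recommended to be a power of 2
    bandwidth=100.0,        # formerly `throughput`
    client_mode=False,      # formerly `listen=True` (note the inverted meaning)
    start=True,
)
```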
@@ -115,11 +115,11 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
         compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
-        throughput: Optional[float] = None,
+        bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
-        listen: bool = True,
+        client_mode: bool = False,
         listen_on: Endpoint = "0.0.0.0:*",
         daemon: bool = True,
         announced_host: Optional[str] = None,
@@ -128,18 +128,19 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
         **kwargs,
     ):
         assert "." not in prefix, "group prefix must be a string without trailing '.'"
-        assert throughput is None or (
-            throughput >= 0 and np.isfinite(np.float32(throughput))
-        ), "throughput must be a non-negative float32"
+        assert bandwidth is None or (
+            bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
+        ), "bandwidth must be a non-negative float32"
         if not is_power_of_two(target_group_size):
             logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert initial_group_bits is None or all(bit in "01" for bit in initial_group_bits)
-        assert listen or not auxiliary, "auxiliary peers must accept incoming connections"
+        assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"

         super().__init__()
         self.dht = dht
-        self.listen, self.listen_on, self.kwargs = listen, listen_on, kwargs
-        if not self.listen:
+        self.client_mode, self.listen_on, self.kwargs = client_mode, listen_on, kwargs
+        self._parent_pid = os.getpid()
+        if self.client_mode:
             self.mode = AveragingMode.CLIENT
         elif auxiliary:
             self.mode = AveragingMode.AUX
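To spell out the flag-to-mode mapping introduced above (the final else branch lies outside this hunk; the enum below is a stand-in mirroring hivemind's `AveragingMode` for illustration):

```python
# Stand-in enum and mapping, for illustration only; the NODE branch is
# outside this hunk and assumed unchanged.
from enum import Enum

class AveragingMode(Enum):
    NODE = 0
    CLIENT = 1
    AUX = 2

def resolve_mode(client_mode: bool, auxiliary: bool) -> AveragingMode:
    assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
    if client_mode:
        return AveragingMode.CLIENT  # joins groups but serves no incoming requests
    elif auxiliary:
        return AveragingMode.AUX     # helps run averaging but contributes no tensors
    return AveragingMode.NODE        # regular peer: contributes tensors and serves requests

assert resolve_mode(client_mode=False, auxiliary=False) is AveragingMode.NODE
```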
@@ -161,7 +162,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
         self.shutdown_timeout = shutdown_timeout
-        self.throughput = throughput
+        self.bandwidth = bandwidth

         self.matchmaking_kwargs = dict(
             prefix=prefix,
@@ -181,10 +182,12 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
         self._port = mp.Value(ctypes.c_uint32, 0)  # assigned when averager starts, accessible via self.port

         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
-        self.allow_state_sharing = (listen and not auxiliary) if allow_state_sharing is None else allow_state_sharing
+        if allow_state_sharing is None:
+            allow_state_sharing = not client_mode and not auxiliary
+        self.allow_state_sharing = allow_state_sharing

         self._averager_endpoint: Optional[Endpoint] = None
-        if not self.listen:
+        if self.client_mode:
             self._averager_endpoint = f"client::{uuid.uuid4()}"

         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
@@ -221,16 +224,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):

     @allow_state_sharing.setter
     def allow_state_sharing(self, value: bool):
-        if value is True and not self.listen:
-            logger.warning(
-                "Cannot allow state sharing: averager in client mode (listen=False) cannot share its state."
-            )
+        if value and self.client_mode:
+            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
         else:
             self._allow_state_sharing.value = value

     @property
     def endpoint(self) -> Optional[Endpoint]:
-        if self.listen and self._averager_endpoint is None:
+        if self._averager_endpoint is None and not self.client_mode:
             assert self.port is not None, "Averager is not running yet"
             self._averager_endpoint = f"{self.announced_host}:{self.port}"
             logger.debug(f"Assuming averager endpoint to be {self._averager_endpoint}")
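This setter change is user-visible: enabling state sharing on a client-mode averager used to log a warning and silently leave the flag off; it now raises. A hypothetical illustration (constructor arguments are placeholders):

```python
# Hypothetical illustration of the new setter behavior.
import torch
import hivemind

dht = hivemind.DHT(start=True)
client = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(8)], dht=dht, prefix="demo_run",
    target_group_size=2, client_mode=True, start=True,
)
try:
    client.allow_state_sharing = True  # client-mode peers cannot serve their state
except ValueError as e:
    print(e)  # previously this case only produced a log warning
```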
@@ -258,7 +259,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
         async def _run():
             grpc.aio.init_grpc_aio()

-            if self.listen:
+            if not self.client_mode:
                 self._server = grpc.aio.server(**self.kwargs, options=GRPC_KEEPALIVE_OPTIONS)
                 averaging_pb2_grpc.add_DecentralizedAveragingServicer_to_server(self, self._server)
                 found_port = self._server.add_insecure_port(self.listen_on)
@@ -269,9 +270,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
                 logger.debug(f"The averager is running in client mode.")

             self._matchmaking = Matchmaking(
-                self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=not self.listen
+                self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
             )
-            if self.listen:
+            if not self.client_mode:
                 asyncio.create_task(self._declare_for_download_periodically())

             self._pending_group_assembled = asyncio.Event()
@@ -312,7 +313,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        if self.listen:
+        if not self.client_mode:
             remaining_tasks.add(self._server.stop(timeout))
         await asyncio.gather(*remaining_tasks)

@@ -374,7 +375,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
         while not future.done():
             try:
                 self._pending_group_assembled.clear()
-                data_for_gather = self.serializer.dumps([weight, self.throughput, self.mode.value, gather_binary])
+                data_for_gather = self.serializer.dumps([weight, self.bandwidth, self.mode.value, gather_binary])
                 group_info = await self._matchmaking.look_for_group(
                     timeout=timeout, data_for_gather=data_for_gather
                 )
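For reference, a sketch of the metadata tuple that travels through `data_for_gather`, assuming `self.serializer` is the `MSGPackSerializer` imported above; all values are illustrative:

```python
# Sketch of the matchmaking metadata, now carrying `bandwidth` instead of
# `throughput`; the values below are illustrative placeholders.
from hivemind.utils.serializer import MSGPackSerializer

weight, bandwidth, mode_value = 1.0, 100.0, 0   # mode_value stands in for AveragingMode.value
gather_binary = MSGPackSerializer.dumps(None)   # stand-in for the user-provided payload
data_for_gather = MSGPackSerializer.dumps([weight, bandwidth, mode_value, gather_binary])

# peers later unpack the same four fields from group_info.gathered:
weight, bandwidth, mode_value, user_payload = MSGPackSerializer.loads(data_for_gather)
```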
@@ -422,16 +423,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            weights, throughputs, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.endpoints, map(self.serializer.loads, user_gathered)))
             modes = tuple(map(AveragingMode, mode_ids))

-            # compute optimal part sizes from peer throughputs; TODO: replace with proper load balancing
-            incoming_throughputs = [
-                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(throughputs, modes)
+            # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
             ]
             peer_fractions = await asyncio.get_event_loop().run_in_executor(
-                None, load_balance_peers, self.total_size, incoming_throughputs, min_vector_size
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
             )

             async with self.get_tensors_async() as local_tensors: