@@ -9,7 +9,6 @@ import multiprocessing as mp
 import os
 import threading
 import weakref
-from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import asdict
 from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 
@@ -21,12 +20,18 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
+from hivemind.compression import (
+    CompressionBase,
+    CompressionInfo,
+    NoCompression,
+    deserialize_torch_tensor,
+    serialize_torch_tensor,
+)
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
-from hivemind.proto import averaging_pb2, runtime_pb2
+from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, switch_to_uvloop
-from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.grpc import combine_from_streaming, split_for_streaming
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.timed_storage import DHTExpiration, ValueWithExpiration, get_dht_time
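Reviewer note: `serialize_torch_tensor` / `deserialize_torch_tensor` are now imported from `hivemind.compression` instead of `hivemind.utils.compression` in this module. A minimal migration sketch for code that mirrored the old import (assuming it only used these two helpers):

```python
# old import path, no longer used by this module
# from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor

# new import path used by this module
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
```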
@@ -51,7 +56,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
     :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
-    :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
+    :param compression: optionally compress tensors with this compression algorithm before running all-reduce
+    :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
+    :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
     :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
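To make the new keyword arguments concrete, here is a hedged usage sketch; the DHT setup, `prefix`, and `target_group_size` values are illustrative assumptions, and only `compression`, `state_compression`, and `tensor_infos` come from this diff:

```python
import torch
import hivemind
from hivemind.compression import CompressionInfo, NoCompression

dht = hivemind.DHT(start=True)  # assumed setup, not part of this diff
tensors = [torch.zeros(1024), torch.zeros(256)]

averager = hivemind.DecentralizedAverager(
    averaged_tensors=tensors,
    dht=dht,
    prefix="my_experiment",  # illustrative value
    target_group_size=4,     # illustrative value
    start=True,
    # arguments introduced by this diff:
    compression=NoCompression(),        # applied to tensor parts during all-reduce
    state_compression=NoCompression(),  # applied when serving state in rpc_download_state
    tensor_infos=[CompressionInfo.from_tensor(t, key=i) for i, t in enumerate(tensors)],
)
```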
@@ -102,7 +109,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
-        compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
+        compression: CompressionBase = NoCompression(),
+        state_compression: CompressionBase = NoCompression(),
+        tensor_infos: Optional[Sequence[CompressionInfo]] = None,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -158,7 +167,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             request_timeout=request_timeout,
         )
         self.allreduce_kwargs = dict(
-            compression_type=compression_type, part_size_bytes=part_size_bytes, min_vector_size=min_vector_size
+            compression=compression,
+            part_size_bytes=part_size_bytes,
+            min_vector_size=min_vector_size,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -169,6 +180,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.state_compression = state_compression
+        self.tensor_infos = tensor_infos
 
         self._ready = MPFuture()
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
@@ -363,7 +376,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                     future.set_result(
                         await asyncio.wait_for(
-                            self._run_allreduce(group_info, **self.allreduce_kwargs), self._allreduce_timeout
+                            self._run_allreduce(group_info, tensor_infos=self.tensor_infos, **self.allreduce_kwargs),
+                            timeout=self._allreduce_timeout,
                         )
                     )
                     # averaging is finished, loop will now exit
@@ -529,24 +543,27 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         if not self.allow_state_sharing:
             return  # deny request and direct peer to the next prospective averager
-        metadata, tensors = await self._get_current_state_from_host_process()
+        metadata, tensors, infos = await self._get_current_state_from_host_process()
+        if infos is None:
+            infos = [CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)]
+        assert len(tensors) == len(infos)
 
-        for tensor in tensors:
-            for part in split_for_streaming(serialize_torch_tensor(tensor)):
+        for tensor, info in zip(tensors, infos):
+            for part in split_for_streaming(self.state_compression.compress(tensor, info, allow_inplace=False)):
                 if metadata is not None:
                     yield averaging_pb2.DownloadData(tensor_part=part, metadata=metadata)
                     metadata = None
                 else:
                     yield averaging_pb2.DownloadData(tensor_part=part)
 
-    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
         """
         Get current state and send it to a peer. executed in the host process. Meant to be overriden.
         :returns: a tuple of (small metadata, sequence of torch tensors)
         :note: metadata must be seriablizable with self.serializer (default = MSGPackSerializer)
         """
         with self.get_tensors() as tensors:
-            return dict(group_key=self.get_group_bits()), tensors
+            return dict(group_key=self.get_group_bits()), tensors, self.tensor_infos
 
     async def _get_current_state_from_host_process(self):
         """Executed in the averager process inside rpc_download_state"""
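Since `get_current_state` now returns a third element, subclasses that override it should return per-tensor `CompressionInfo` as well (or `None` to fall back to the defaults above). A hedged sketch of an override matching the new signature, with illustrative metadata:

```python
from typing import Any, Sequence, Tuple

import torch
from hivemind.averaging import DecentralizedAverager
from hivemind.compression import CompressionInfo


class MyAverager(DecentralizedAverager):
    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor], Sequence[CompressionInfo]]:
        with self.get_tensors() as tensors:
            metadata = dict(group_key=self.get_group_bits())
            # per-tensor CompressionInfo tells state_compression how to serialize each tensor;
            # falling back to from_tensor mirrors the default used in rpc_download_state
            infos = self.tensor_infos or [
                CompressionInfo.from_tensor(tensor, key=i) for i, tensor in enumerate(tensors)
            ]
            return metadata, tensors, infos
```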
@@ -618,7 +635,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         finally:
             if not future.done():
-                logger.warning("Averager could not load state from peers: none of the requests succeeded.")
                 future.set_result(None)
 
     def get_group_bits(self, wait: bool = True):
@@ -681,7 +697,11 @@ def _background_thread_fetch_current_state(
             get_current_state = get_current_state_ref()
             if get_current_state is None:
                 break
-            state_metadata, state_tensors = get_current_state()
+            state = get_current_state()
+            assert 0 < len(state) <= 3
+            if len(state) != 3:
+                state = tuple(state + (None,) * (3 - len(state)))
+            state_metadata, state_tensors, tensor_infos = state
             del get_current_state
 
             state_metadata = serializer.dumps(state_metadata)
@@ -689,7 +709,7 @@ def _background_thread_fetch_current_state(
                 tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in state_tensors
             )
             # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
-            future.set_result((state_metadata, state_tensors))
+            future.set_result((state_metadata, state_tensors, tensor_infos))
         except BaseException as e:
             future.set_exception(e)
             logger.warning(e)
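The padding above keeps older overrides working: a `get_current_state` that still returns a two-element tuple ends up with `tensor_infos = None`, which `rpc_download_state` then replaces with `CompressionInfo.from_tensor` defaults. A minimal sketch of such a legacy-style override (illustrative only):

```python
from hivemind.averaging import DecentralizedAverager


class LegacyStateAverager(DecentralizedAverager):
    def get_current_state(self):
        # pre-refactor contract: (metadata, tensors) with no CompressionInfo;
        # _background_thread_fetch_current_state pads the missing third element with None
        with self.get_tensors() as tensors:
            return dict(group_key=self.get_group_bits()), tensors
```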