@@ -23,7 +23,7 @@ from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
     serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
-from hivemind.utils.serializer import MSGPackSerializer
+from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils import Endpoint, Port, MPFuture, get_logger
 
 # flavour types
@@ -126,8 +126,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         self._averager_endpoint: Optional[Endpoint] = None
         self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
-        background_fetcher = threading.Thread(daemon=True, target=_background_thread_fetch_current_state,
-                                              args=[self.pipe, weakref.WeakMethod(self.get_current_state)])
+        background_fetcher = threading.Thread(
+            daemon=True, target=_background_thread_fetch_current_state,
+            args=[self.serializer, self.pipe, weakref.WeakMethod(self.get_current_state)])
         background_fetcher.start()
         if start:
             self.run_in_background(await_ready=True)
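
Aside (not part of the diff): the "background thread weakref" comment above relies on a detail of weakref.WeakMethod that is easy to miss. A bound method holds a strong reference to its instance, so handing self.get_current_state directly to a long-lived daemon thread would keep the averager alive forever; WeakMethod lets the thread re-resolve the method on demand and stop once the averager is gone. A minimal, self-contained sketch (Owner is a hypothetical stand-in class):

    import weakref

    class Owner:
        def get_current_state(self):
            return "state"

    owner = Owner()
    ref = weakref.WeakMethod(owner.get_current_state)
    assert ref()() == "state"   # while the owner is alive, the weak method resolves and can be called
    del owner                   # on CPython, refcounting frees the instance immediately
    assert ref() is None        # the background thread can detect this and exit instead of pinning the averager
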
@@ -326,13 +327,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 else:
                     yield averaging_pb2.DownloadData(tensor_part=part)
 
-    def get_current_state(self) -> Tuple[bytes, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
         """
         Get current state and send it to a peer. executed in the host process. Meant to be overriden.
-        :returns: a tuple of (serialized_metadata, sequence of torch tensors)
+        :returns: a tuple of (small metadata, sequence of torch tensors)
+        :note: metadata must be serializable with self.serializer (default = MSGPackSerializer)
         """
         with self.get_tensors() as tensors:
-            return self.serializer.dumps(dict(group_key=self.get_group_bits())), tensors
+            return dict(group_key=self.get_group_bits()), tensors
 
     async def _get_current_state_from_host_process(self):
         """ Executed in the averager process inside rpc_download_state """
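
To illustrate what this hunk means for subclasses, here is a hedged sketch of an overridden get_current_state under the new contract: return plain metadata (anything self.serializer can handle, MSGPackSerializer by default) plus the tensors, with no manual dumps() call. MyAverager, the extra step field, and the top-level hivemind import are assumptions for illustration, not part of this diff:

    from typing import Any, Sequence, Tuple

    import torch
    from hivemind import DecentralizedAverager  # assumes the class is re-exported at the package top level


    class MyAverager(DecentralizedAverager):
        def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
            with self.get_tensors() as tensors:
                # metadata stays a plain dict; _background_thread_fetch_current_state
                # serializes it with self.serializer before sending it to the peer
                return dict(group_key=self.get_group_bits(), step=0), tensors
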
@@ -433,9 +435,11 @@ def is_power_of_two(n):
     return (n != 0) and (n & (n - 1) == 0)
 
 
-def _background_thread_fetch_current_state(pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod):
+def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.connection.Connection,
+                                           get_current_state_ref: weakref.WeakMethod):
     """
     Executed in the host process as a background thread. Fetches the averager state when asked by peers.
+    :param serializer: a serializer with which to convert metadata into bytes
     :param pipe: DecentralizedAverager's control pipe (from host process side)
     :param get_current_state_ref: a WeakMethod wrapped around DecentraliedAverager.get_current_state (instance-bound)
     """
@@ -452,7 +456,7 @@ def _background_thread_fetch_current_state(pipe: mp.connection.Connection, get_c
                 state_metadata, state_tensors = get_current_state()
                 del get_current_state
 
-                assert isinstance(state_metadata, bytes)
+                state_metadata = serializer.dumps(state_metadata)
                 state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
                                       for tensor in state_tensors)
                 # note: we cast tensors to CPU on host side to avoid initializing cuda in the guest process
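
Finally, a short illustration of what the serializer.dumps call above does with the default MSGPackSerializer (assuming it exposes the usual dumps/loads pair, as the diff itself implies): the metadata returned by get_current_state only needs to be msgpack-compatible, and it is converted to bytes on the host side before being piped to the averager process. The sample dict is hypothetical:

    from hivemind.utils.serializer import MSGPackSerializer

    metadata = dict(group_key='0110', step=7)               # any msgpack-serializable object works
    payload = MSGPackSerializer.dumps(metadata)             # bytes, as sent through the control pipe
    assert MSGPackSerializer.loads(payload) == metadata     # round-trips back to the same dict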