@@ -21,7 +21,7 @@ from hivemind.utils.grpc import ChannelCache, GRPC_KEEPALIVE_OPTIONS, \
     serialize_torch_tensor, deserialize_torch_tensor, split_for_streaming, combine_from_streaming
 from hivemind.utils.asyncio import anext, achain, aiter, switch_to_uvloop
 from hivemind.utils.timed_storage import get_dht_time, ValueWithExpiration, DHTExpiration
-from hivemind.utils.serializer import PickleSerializer, MSGPackSerializer
+from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils import Endpoint, Port, MPFuture, get_logger

 # flavour types
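
Note on the import change above: dropping PickleSerializer means state metadata now travels through MSGPackSerializer everywhere. A minimal round-trip sketch, assuming only the dumps/loads pair that this diff itself calls via self.serializer:

from hivemind.utils.serializer import MSGPackSerializer

# metadata is now restricted to msgpack-friendly types (dicts, lists, strings, numbers);
# arbitrary Python objects that pickle would have accepted will no longer serialize
metadata = dict(group_key='0110', learning_rate=1e-3)
serialized = MSGPackSerializer.dumps(metadata)   # -> bytes
assert MSGPackSerializer.loads(serialized) == metadata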

@@ -303,7 +303,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         Get the up-to-date trainer state from a peer.
         The state consists of two parts: (metadata, tensors)

-        - metadata is a small pickle-serialized entry meant to store scalars and hyperparameters
+        - metadata is a small serialized bytestring meant to store scalars and hyperparameters
         - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
         chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)

@@ -317,13 +317,13 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             else:
                 yield averaging_pb2.DownloadData(tensor_part=part)

-    def get_current_state(self) -> Tuple[Any, Sequence[torch.Tensor]]:
+    def get_current_state(self) -> Tuple[bytes, Sequence[torch.Tensor]]:
         """
         Get current state and send it to a peer. Executed in the host process. Meant to be overridden.
-        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        :returns: a tuple of (serialized_metadata, sequence of torch tensors)
         """
         with self.get_tensors() as tensors:
-            return dict(group_key=self.get_group_bits()), tensors
+            return self.serializer.dumps(dict(group_key=self.get_group_bits())), tensors

     async def _get_current_state_from_host_process(self):
         """ Executed in the averager process inside rpc_download_state """

@@ -338,8 +338,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             assert trigger == '_TRIGGER_GET_CURRENT_STATE'
             try:
                 state_metadata, state_tensors = self.get_current_state()
-                # note: serialize here to avoid initializing cuda in the guest process
-                state_metadata = PickleSerializer.dumps(state_metadata)
+                # note: we cast tensors to CPU on the host side to avoid initializing CUDA in the guest process
+                assert isinstance(state_metadata, bytes)
                 state_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad)
                                       for tensor in state_tensors)
                 future.set_result((state_metadata, state_tensors))
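
The tensor-casting idiom above is easy to misread: .detach() not only cuts the autograd graph but also resets requires_grad, so the flag is restored explicitly. A standalone illustration:

import torch

param = torch.zeros(3, requires_grad=True)
copy = param.cpu().detach()                      # detach() resets requires_grad to False
copy = copy.requires_grad_(param.requires_grad)  # restore the flag to match the original
assert copy.requires_grad == param.requires_grad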

@@ -348,8 +348,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 logger.warning(e)
                 continue

-    def load_state_from_peers(self, wait=True) -> Optional[Any]:
-        """ Try to download the latest optimizer state one of the existing peer """
+    def load_state_from_peers(self, wait=True) -> Optional[Tuple[bytes, Sequence[torch.Tensor]]]:
+        """
+        Try to download the latest optimizer state from one of the existing peers.
+        :returns: on success, a 2-tuple of (serialized_metadata, tensors), where
+
+        - serialized_metadata is a small bytestring containing **serialized** metadata (e.g. hyperparameters)
+        - tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics
+
+        The exact contents of both serialized_metadata and tensors are determined by the get_current_state method
+        """
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
         return future.result() if wait else future
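
A typical caller unpacks the result and copies the downloaded tensors into its own parameters. A hedged sketch (averager and model are assumed to exist; error handling omitted):

import torch

result = averager.load_state_from_peers(wait=True)
if result is not None:
    metadata, tensors = result
    with torch.no_grad():
        for local_param, downloaded in zip(model.parameters(), tensors):
            local_param.copy_(downloaded)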

@@ -376,7 +384,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             current_tensor_parts, tensors = [], []
             async for message in stream:
                 if message.metadata:
-                    metadata = PickleSerializer.loads(message.metadata)
+                    metadata = self.serializer.loads(message.metadata)
                 if message.tensor_part.dtype and current_tensor_parts:
                     # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                     tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
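
This reassembly mirrors the sender's splitting in rpc_download_state. A standalone round trip using the same helpers imported at the top of the module, assuming they operate on serialized tensors as used here (chunk size chosen arbitrarily):

import torch
from hivemind.utils.grpc import serialize_torch_tensor, deserialize_torch_tensor, \
    split_for_streaming, combine_from_streaming

tensor = torch.randn(1024)
parts = list(split_for_streaming(serialize_torch_tensor(tensor), chunk_size_bytes=16 * 1024))
restored = deserialize_torch_tensor(combine_from_streaming(parts))
assert torch.allclose(tensor, restored)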

@@ -398,6 +406,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             future.set_result(None)

     def get_group_bits(self, wait: bool = True):
+        """
+        :param wait: if True, return bits immediately. Otherwise return an awaitable MPFuture
+        :returns: averager's current group key bits (without prefix)
+        """
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_get_group_bits', [], dict(future=_future)))
         return future.result() if wait else future

@@ -406,6 +418,10 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         future.set_result(self._matchmaking.group_key_manager.group_bits)

     def set_group_bits(self, group_bits: str, wait: bool = True):
+        """
+        :param group_bits: group bits (a string of '0's and '1's) to be used in the averager's group key
+        :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
+        """
         future, _future = MPFuture.make_pair()
         assert all(bit in '01' for bit in group_bits)
         self.pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
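
Together with get_group_bits above, this pair lets external code inspect and steer matchmaking. A small usage sketch (the averager instance is assumed to exist; the bit string keeps its original length):

bits = averager.get_group_bits()         # e.g. '0110', without the key prefix
assert all(bit in '01' for bit in bits)
averager.set_group_bits('1' + bits[1:])  # steer the averager toward a different group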