@@ -109,7 +109,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_expiration: float = None,
         min_matchmaking_time: float = 5.0,
         request_timeout: float = 3.0,
-        load_state_message_timeout: float = 10.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
@@ -165,7 +164,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             tensor.share_memory_()
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
-        self.load_state_message_timeout = load_state_message_timeout
         self.shutdown_timeout = shutdown_timeout
         self.bandwidth = bandwidth

@@ -618,10 +616,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future = MPFuture()
-        self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
+        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
         return future.result(timeout=timeout) if wait else future

-    async def _load_state_from_peers(self, future: MPFuture):
+    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
@@ -645,7 +643,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []

-                        async for message in aiter_with_timeout(stream, timeout=self.load_state_message_timeout):
+                        async for message in aiter_with_timeout(stream, timeout=timeout or self.request_timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
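
For reference, a minimal usage sketch of the API after this change. The setup below (DHT bootstrap, tensor shape, the "demo_run" prefix) is illustrative and not part of the diff; the point is that the per-message download timeout now comes from the timeout argument of load_state_from_peers (falling back to request_timeout) instead of the removed load_state_message_timeout constructor argument.

import torch
import hivemind
from hivemind.averaging import DecentralizedAverager

# Illustrative setup, not from the diff: a single local averager over one tensor.
dht = hivemind.DHT(start=True)
averager = DecentralizedAverager(
    averaged_tensors=[torch.zeros(3)],
    dht=dht,
    prefix="demo_run",
    start=True,
)

# The timeout argument is forwarded to the background _load_state_from_peers coroutine,
# where it bounds each streamed message (timeout or self.request_timeout).
result = averager.load_state_from_peers(wait=True, timeout=30.0)
if result is not None:
    metadata, tensors = result  # contents are defined by get_current_state on the donor peer

averager.shutdown()
dht.shutdown()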