|
@@ -109,6 +109,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
averaging_expiration: float = None,
|
|
averaging_expiration: float = None,
|
|
min_matchmaking_time: float = 5.0,
|
|
min_matchmaking_time: float = 5.0,
|
|
request_timeout: float = 3.0,
|
|
request_timeout: float = 3.0,
|
|
|
|
+ load_state_message_timeout: float = 10.0,
|
|
averaging_alpha: float = 1.0,
|
|
averaging_alpha: float = 1.0,
|
|
part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
|
|
part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
|
|
allreduce_timeout: Optional[float] = None,
|
|
allreduce_timeout: Optional[float] = None,
|
|
@@ -164,6 +165,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
tensor.share_memory_()
|
|
tensor.share_memory_()
|
|
self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
|
|
self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
|
|
self.schema_hash = compute_schema_hash(self._averaged_tensors)
|
|
self.schema_hash = compute_schema_hash(self._averaged_tensors)
|
|
|
|
+ self.load_state_message_timeout = load_state_message_timeout
|
|
self.shutdown_timeout = shutdown_timeout
|
|
self.shutdown_timeout = shutdown_timeout
|
|
self.bandwidth = bandwidth
|
|
self.bandwidth = bandwidth
|
|
|
|
|
|
@@ -643,7 +645,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
|
|
stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
|
|
current_tensor_parts, tensors = [], []
|
|
current_tensor_parts, tensors = [], []
|
|
|
|
|
|
- async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
|
|
|
|
|
|
+ async for message in aiter_with_timeout(stream, timeout=self.load_state_message_timeout):
|
|
if message.metadata:
|
|
if message.metadata:
|
|
metadata = self.serializer.loads(message.metadata)
|
|
metadata = self.serializer.loads(message.metadata)
|
|
if message.tensor_part.dtype and current_tensor_parts:
|
|
if message.tensor_part.dtype and current_tensor_parts:
|