justheuristic 3 роки тому
батько
коміт
98a6d62296
1 змінений файл: 3 додано та 5 видалено
  1. 3 5
      hivemind/averaging/averager.py

+ 3 - 5
hivemind/averaging/averager.py

@@ -109,7 +109,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_expiration: float = None,
         min_matchmaking_time: float = 5.0,
         request_timeout: float = 3.0,
-        load_state_message_timeout: float = 10.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
@@ -165,7 +164,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             tensor.share_memory_()
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
-        self.load_state_message_timeout = load_state_message_timeout
         self.shutdown_timeout = shutdown_timeout
         self.bandwidth = bandwidth
 
@@ -618,10 +616,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future = MPFuture()
-        self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
+        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
         return future.result(timeout=timeout) if wait else future
 
-    async def _load_state_from_peers(self, future: MPFuture):
+    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
@@ -645,7 +643,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
-                        async for message in aiter_with_timeout(stream, timeout=self.load_state_message_timeout):
+                        async for message in aiter_with_timeout(stream, timeout=timeout or self.request_timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts: