Expose request_timeout to DistributedBloomConfig (#105)

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Artem Chumachenko, 2 years ago
commit 7d859a947b

+ 1 - 1
src/petals/client/inference_session.py

@@ -184,7 +184,7 @@ class InferenceSession:
                         stub,
                         span_uids,
                         rpc_info=self._sequence_manager.rpc_info,
-                        timeout=self._sequence_manager.timeout,
+                        timeout=self._sequence_manager.request_timeout,
                         max_length=self._max_length,
                         **self._metadata,
                     )

+ 2 - 1
src/petals/client/remote_model.py

@@ -36,6 +36,7 @@ class DistributedBloomConfig(BloomConfig):
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+    request_timeout: int = 20  # number of seconds to wait for a result from each node
 
 
 original_register_parameter = nn.Module.register_parameter
@@ -84,7 +85,7 @@ class DistributedBloomModel(BloomModel):
             else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
-        self.h = RemoteSequential(config, dht, config.dht_prefix)
+        self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)
 
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
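
With the field exposed on the config, clients can tune the per-node timeout without touching the lower-level classes. A minimal usage sketch (the checkpoint name and the 60-second value are illustrative, not part of this commit):

    from petals.client.remote_model import DistributedBloomConfig, DistributedBloomModel

    # Checkpoint name is illustrative; any BLOOM checkpoint served via Petals works.
    config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")
    config.request_timeout = 60  # allow each remote node up to 60 s per request

    # DistributedBloomModel passes config.request_timeout to RemoteSequential,
    # which hands it to the RemoteSequenceManager it creates (see below).
    model = DistributedBloomModel.from_pretrained("bigscience/bloom-petals", config=config)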

+ 2 - 1
src/petals/client/remote_sequential.py

@@ -30,6 +30,7 @@ class RemoteSequential(nn.Module):
         dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
+        request_timeout: int = 20,
     ):
         super().__init__()
         self.config = config
@@ -41,7 +42,7 @@ class RemoteSequential(nn.Module):
         block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, request_timeout=request_timeout)
             self.is_subsequence = False
         else:
             logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")

+ 2 - 2
src/petals/client/sequence_manager.py

@@ -30,7 +30,7 @@ class RemoteSequenceManager:
         block_uids: Sequence[ModuleUID],
         p2p: P2P,
         max_retries: int = 3,
-        timeout: float = 20,
+        request_timeout: float = 20,
         min_backoff: float = 1,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
@@ -41,7 +41,7 @@ class RemoteSequenceManager:
         self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
         self.last_update_time: DHTExpiration = -float("inf")
         self.max_retries = max_retries
-        self.timeout, self.min_backoff = timeout, min_backoff
+        self.request_timeout, self.min_backoff = request_timeout, min_backoff
         self._rpc_info = None
         self.lock_changes = threading.Lock()
         self.policy = NoSpendingPolicy()
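
This renames both the constructor argument and the attribute, so code that constructed the manager directly must switch to the new keyword. A minimal sketch, assuming dht, block_uids, and p2p are already set up:

    manager = RemoteSequenceManager(dht, block_uids, p2p, request_timeout=30.0)
    assert manager.request_timeout == 30.0  # the former .timeout attribute is gone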

+ 2 - 2
src/petals/client/sequential_autograd.py

@@ -77,7 +77,7 @@ async def sequential_forward(
                     stub,
                     sequence_manager.rpc_info,
                     *inputs_and_prompts,
-                    timeout=sequence_manager.timeout,
+                    timeout=sequence_manager.request_timeout,
                     metadata=metadata,
                 )
 
@@ -161,7 +161,7 @@ async def sequential_backward(
                     inputs,
                     grad_outputs,
                     prompts[span.start : span.end],
-                    timeout=sequence_manager.timeout,
+                    timeout=sequence_manager.request_timeout,
                     metadata=metadata,
                 )
                 grad_outputs = [grad_outputs]
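
Conceptually, request_timeout bounds how long sequential_forward, sequential_backward, and inference sessions wait for each node's reply before the surrounding retry logic picks another server. A simplified sketch of that pattern (not the actual hivemind RPC code):

    import asyncio

    async def call_with_timeout(rpc_coroutine, timeout: float):
        # Equivalent in spirit to passing timeout=sequence_manager.request_timeout
        # to the remote-call helpers: fail fast on an unresponsive node so the
        # caller can back off and retry elsewhere.
        return await asyncio.wait_for(rpc_coroutine, timeout=timeout)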