Browse Source

Add connect_timeout (#423)

Alexander Borzunov 2 years ago
parent
commit
44fefa5e54

+ 1 - 1
src/petals/client/inference_session.py

@@ -75,7 +75,7 @@ class _ServerInferenceSession:
         inputs_queue = asyncio.Queue()
         outputs_stream = await asyncio.wait_for(
             stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
-            config.request_timeout,
+            config.connect_timeout,
         )
         return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
 

+ 15 - 14
src/petals/client/remote_forward_backward.py

@@ -13,52 +13,53 @@ from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
+from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.data_structures import ModuleUID, RPCInfo
 
 
 async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
 ) -> List[torch.Tensor]:
     outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
-        timeout=timeout,
+        timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
 
 async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
 ) -> List[torch.Tensor]:
     grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
         runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
-        timeout=timeout,
+        timeout=config.request_timeout,
     )
     return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
 
 
 async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
         runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
-    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout)
-    outputs = aiter_with_timeout(outputs, timeout)
+    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), config.connect_timeout)
+    outputs = aiter_with_timeout(outputs, config.request_timeout)
     return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
 
 
 async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: SequenceManagerConfig, **kwargs
 ) -> List[torch.Tensor]:
     parts = (
         runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
         for tensor in serialized_tensors
         for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
     )
-    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout)
-    grad_inputs = aiter_with_timeout(grad_inputs, timeout)
+    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), config.connect_timeout)
+    grad_inputs = aiter_with_timeout(grad_inputs, config.request_timeout)
     return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
 
 
@@ -67,7 +68,7 @@ async def run_remote_forward(
     stub: StubBase,
     rpc_info: RPCInfo,
     *inputs: torch.Tensor,
-    timeout: float,
+    config: SequenceManagerConfig,
     metadata: Optional[bytes] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, ...]:
@@ -110,7 +111,7 @@ async def run_remote_forward(
     size = sum(t.element_size() * t.nelement() for t in inputs)
     forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
-    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
+    deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
@@ -121,7 +122,7 @@ async def run_remote_backward(
     inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
-    timeout: float,
+    config: SequenceManagerConfig,
     metadata: Optional[bytes] = None,
     **kwargs,
 ) -> Sequence[torch.Tensor]:
@@ -153,5 +154,5 @@ async def run_remote_backward(
     size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
     backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary
     # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space
-    deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, timeout, metadata=metadata, **kwargs)
+    deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs)
     return deserialized_grad_inputs

+ 1 - 0
src/petals/client/routing/sequence_manager.py

@@ -40,6 +40,7 @@ class SequenceManagerConfig:
     allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
     use_server_to_server: bool = True  # Use direct server-to-server communication
 
+    connect_timeout: float = 5  # timeout for opening a connection
     request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
     update_period: float = 60  # refresh DHT information once in this many seconds
 

+ 2 - 2
src/petals/client/sequential_autograd.py

@@ -76,7 +76,7 @@ async def sequential_forward(
                     stub,
                     sequence_manager.rpc_info,
                     *inputs_and_prompts,
-                    timeout=sequence_manager.config.request_timeout,
+                    config=sequence_manager.config,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )
 
@@ -161,7 +161,7 @@ async def sequential_backward(
                     inputs,
                     grad_outputs,
                     prompts[span.start : span.end],
-                    timeout=sequence_manager.config.request_timeout,
+                    config=sequence_manager.config,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )
                 grad_outputs = [grad_outputs]