5
0
Эх сурвалжийг харах

Implement timeouts in forward/backward (#90)

Alexander Borzunov 2 жил өмнө
parent
commit
dc6ecccac5

+ 53 - 55
src/client/remote_forward_backward.py

@@ -10,14 +10,60 @@ from hivemind.compression.serialization import deserialize_tensor_stream, deseri
 from hivemind.p2p import StubBase
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import ModuleUID, RPCInfo
 
 
+async def _forward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        timeout=timeout,
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def _backward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        timeout=timeout,
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def _forward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout)
+    outputs = aiter_with_timeout(outputs, timeout)
+    return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
+
+
+async def _backward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout)
+    grad_inputs = aiter_with_timeout(grad_inputs, timeout)
+    return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
+
+
 async def run_remote_forward(
-    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.
@@ -57,53 +103,13 @@ async def run_remote_forward(
     # call RPC on remote server
     size = sum(t.element_size() * t.nelement() for t in inputs)
     if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
+        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
     else:
-        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
+        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
-async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
-
-    outputs = await stub.rpc_forward_stream(
-        amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
-            iter_as_aiter(split),
-        ),
-    )
-
-    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
-    return await deserialize_tensor_stream(tensors_stream)
-
-
-async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
-    )
-    return [deserialize_torch_tensor(t) for t in outputs.tensors]
-
-
-async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
-
-    grad_inputs = await stub.rpc_backward_stream(
-        amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
-            iter_as_aiter(split),
-        ),
-    )
-    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
-    return await deserialize_tensor_stream(tensors_stream)
-
-
 async def run_remote_backward(
     uid: ModuleUID,
     stub: StubBase,
@@ -111,6 +117,7 @@ async def run_remote_backward(
     inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
+    timeout: float,
     **kwargs,
 ) -> Sequence[torch.Tensor]:
     """
@@ -140,17 +147,8 @@ async def run_remote_backward(
 
     size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
     if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
+        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
     else:
-        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
+        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
     return deserialized_grad_inputs
-
-
-async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
-    )
-    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]

+ 10 - 1
src/client/sequence_manager.py

@@ -24,7 +24,15 @@ class RemoteSequenceManager:
     In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
     """
 
-    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
+    def __init__(
+        self,
+        dht: DHT,
+        block_uids: Sequence[ModuleUID],
+        p2p: P2P,
+        max_retries: int = 3,
+        timeout: float = 5,
+        min_backoff: float = 1,
+    ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
         self.dht, self.p2p = dht, p2p
         self.block_uids: List[ModuleUID] = list(block_uids)
@@ -33,6 +41,7 @@ class RemoteSequenceManager:
         self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
         self.last_update_time: DHTExpiration = -float("inf")
         self.max_retries = max_retries
+        self.timeout, self.min_backoff = timeout, min_backoff
         self._rpc_info = None
         self.lock_changes = threading.Lock()
         self.update_()

+ 12 - 6
src/client/sequential_autograd.py

@@ -24,7 +24,6 @@ async def sequential_forward(
     sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
     end_index: Optional[int] = None,
-    min_backoff: float = 1.0,
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -53,7 +52,9 @@ async def sequential_forward(
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
-                (outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
+                (outputs,) = await run_remote_forward(
+                    span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
+                )
 
                 assert isinstance(outputs, torch.Tensor)
                 assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -66,7 +67,7 @@ async def sequential_forward(
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(min_backoff * 2**attempt_no)
+                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
 
                 backup_sequences = sequence_manager.make_sequence(span.start)
                 assert backup_sequences[0].start == span.start
@@ -81,7 +82,6 @@ async def sequential_backward(
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
-    min_backoff: float = 1.0,
 ) -> Sequence[torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
@@ -98,14 +98,20 @@ async def sequential_backward(
             try:
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
+                    span_uids,
+                    stub,
+                    sequence_manager.rpc_info,
+                    inputs,
+                    grad_outputs,
+                    prompts[span.start : span.end],
+                    timeout=sequence_manager.timeout,
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(min_backoff * 2**attempt_no)
+                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
 
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
                     inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end