
Implement exponential backoff for forward & backward (#85)

Alexander Borzunov 2 years ago
parent commit 57e8d2e721
2 changed files with 10 additions and 3 deletions:
  1. requirements.txt (+1 −1)
  2. src/client/sequential_autograd.py (+9 −2)

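For context, the retry logic added in this commit follows the standard exponential backoff pattern: after the n-th failed attempt the client sleeps for min_backoff * 2**n seconds before rebuilding the sequence and retrying. A minimal sketch of that pattern is below; the names retry_with_backoff and op are illustrative and not part of the actual change.

    import asyncio
    import itertools
    import logging

    async def retry_with_backoff(op, min_backoff: float = 1.0):
        """Retry an async operation, sleeping min_backoff * 2**attempt_no seconds after each failure."""
        for attempt_no in itertools.count():
            try:
                return await op()
            except Exception as e:
                logging.warning(f"Caught {e}, retrying (attempt {attempt_no})", exc_info=True)
                await asyncio.sleep(min_backoff * 2**attempt_no)
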
requirements.txt (+1 −1)

@@ -4,5 +4,5 @@ accelerate==0.10.0
 huggingface-hub==0.7.0
 transformers==4.21.3
 protobuf>=3.12.2,<4.0.0
-hivemind==1.1.2
+git+https://github.com/learning-at-home/hivemind@94c985d2dc7a79a091e46c755e9f2f4469b164c7
 humanfriendly

src/client/sequential_autograd.py (+9 −2)

@@ -2,6 +2,7 @@
 A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
 """
 import asyncio
+import itertools
 import logging
 from typing import List, Optional, Sequence, Tuple
 
@@ -23,6 +24,7 @@ async def sequential_forward(
     sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
     end_index: Optional[int] = None,
+    min_backoff: float = 1.0,
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -44,7 +46,7 @@ async def sequential_forward(
     outputs = inputs
 
     while len(sequences) > 0:
-        while True:
+        for attempt_no in itertools.count():
             span = sequences.pop(0)
             span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
             try:
@@ -64,6 +66,8 @@ async def sequential_forward(
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
+                await asyncio.sleep(min_backoff * 2**attempt_no)
+
                 backup_sequences = sequence_manager.make_sequence(span.start)
                 assert backup_sequences[0].start == span.start
                 sequences = backup_sequences
@@ -77,6 +81,7 @@ async def sequential_backward(
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
+    min_backoff: float = 1.0,
 ) -> Sequence[torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
@@ -86,7 +91,7 @@ async def sequential_backward(
 
     grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
-        while True:
+        for attempt_no in itertools.count():
             inputs = intermediate_inputs.pop(-1)
             span = forward_sequences.pop(-1)
             span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
@@ -100,6 +105,8 @@ async def sequential_backward(
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
+                await asyncio.sleep(min_backoff * 2**attempt_no)
+
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
                     inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
                 )