@@ -3,11 +3,12 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 import asyncio
 import itertools
-import logging
+from collections import deque
 from typing import List, Optional, Sequence, Tuple
 
 import torch
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.utils.logging import get_logger
 
 from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
 from src.client.sequence_manager import RemoteSequenceManager
@@ -15,6 +16,8 @@ from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy
 
+logger = get_logger(__file__)
+
 MAX_TOKENS_IN_BATCH = 1024
 
 
@@ -39,16 +42,25 @@ async def sequential_forward(
         sequence_manager.block_uids
     )  # should be n_layers - 1 but add extra prompts for convenience
 
-    sequences = sequence_manager.make_sequence(start_index, end_index)
+    sequences = deque()
     intermediate_inputs = []
     done_sequences = []
     outputs = inputs
 
-    while len(sequences) > 0:
+    block_idx = start_index
+    while block_idx < len(sequence_manager):
         for attempt_no in itertools.count():
-            span = sequences.pop(0)
-            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+            logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
             try:
+                if attempt_no >= 1:
+                    sequence_manager.update_()
+                if not sequences or attempt_no >= 1:
+                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
+                    logger.debug(f"Found path from block {block_idx} via {len(sequences)} servers")
+
+                span = sequences.popleft()
+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
@@ -64,14 +76,16 @@ async def sequential_forward(
                 done_sequences.append(span)
 
                 inputs = outputs
+                block_idx = span.end
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
-
-                backup_sequences = sequence_manager.make_sequence(span.start)
-                assert backup_sequences[0].start == span.start
-                sequences = backup_sequences
+                delay = sequence_manager.min_backoff * 2**attempt_no
+                logger.warning(
+                    f"Caught exception when running forward from block {block_idx} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                logger.debug("See detailed traceback below:", exc_info=True)
+                await asyncio.sleep(delay)
 
     return outputs, intermediate_inputs, done_sequences
 
@@ -110,7 +124,7 @@ async def sequential_backward(
                 grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
+                logger.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                 await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
 
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(