Bladeren bron

cover edge case

justheuristic 2 jaar geleden
bovenliggende
commit
6675d255bd
3 gewijzigde bestanden met 4 toevoegingen en 4 verwijderingen
  1. 1 1
      src/client/remote_forward_backward.py
  2. 1 2
      src/client/sequence_manager.py
  3. 2 1
      src/client/sequential_autograd.py

+ 1 - 1
src/client/remote_forward_backward.py

@@ -5,7 +5,7 @@ import asyncio
 from typing import Iterable, List, Sequence, Tuple
 
 import torch
-from hivemind import nested_compare, nested_flatten, serialize_torch_tensor, nested_pack
+from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
 from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
 from hivemind.p2p import StubBase
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE

+ 1 - 2
src/client/sequence_manager.py

@@ -9,7 +9,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import NoSpendingPolicy
+from src.client.spending_policy import NoSpendingPolicy
 from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from src.dht_utils import get_remote_module_infos
 from src.server.handler import TransformerConnectionHandler
@@ -31,7 +31,6 @@ class RemoteSequenceManager:
         self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
         self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
         self.last_update_time: DHTExpiration = -float("inf")
-        self.spending_policy = NoSpendingPolicy()
         self.max_retries = max_retries
         self._rpc_info = None
         self.lock_changes = threading.Lock()

+ 2 - 1
src/client/sequential_autograd.py

@@ -8,7 +8,7 @@ from typing import List, Optional, Sequence, Tuple
 import torch
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 
-from src.client.remote_forward_backward import run_remote_forward, run_remote_backward
+from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from src.server.handler import TransformerConnectionHandler
@@ -41,6 +41,7 @@ async def sequential_forward(
     sequences = sequence_manager.make_sequence(start_index, end_index)
     intermediate_inputs = []
     done_sequences = []
+    outputs = inputs
 
     while len(sequences) > 0:
         while True: