Przeglądaj źródła

make tests pass for now

justheuristic 3 lat temu
rodzic
commit
a87be914f3

+ 2 - 2
src/client/remote_sequential.py

@@ -40,7 +40,7 @@ class RemoteSequential(nn.Module):
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
 
         num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
-        block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
+        block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks))
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
             self.sequence_manager = RemoteSequenceManager(dht, block_uids, p2p=self.p2p, start=True)
@@ -48,7 +48,7 @@ class RemoteSequential(nn.Module):
         else:
             logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
             self.sequence_manager = sequence_manager
-            assert isinstance(sequence_manager.block_uids, list)
+            assert isinstance(sequence_manager.block_uids, tuple)
             self.is_subsequence = self.sequence_manager.block_uids != block_uids
 
     def forward(self, inputs: torch.Tensor):

+ 1 - 5
src/client/routing/routing_strategy.py

@@ -8,8 +8,6 @@ from src.data_structures import RemoteSpanInfo, ServerState
 
 
 class RoutingStrategyBase(ABC):
-    name: str  # used in RemoteSequenceManager.make_sequence(mode, **kwargs)
-
     def update_(self):
         """Called when sequence manager fetches new info from the dht"""
         raise NotImplementedError()
@@ -22,8 +20,6 @@ class RoutingStrategyBase(ABC):
 class RandomRoutingStrategy(RoutingStrategyBase):
     """choose a random compatible server at each branch and include all layers served by it"""
 
-    name = "RANDOM"
-
     def __init__(self, sequence_info: RemoteSequenceInfo):
         self.sequence_info = sequence_info
         self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
@@ -79,4 +75,4 @@ class RandomRoutingStrategy(RoutingStrategyBase):
         return span_sequence
 
 
-ALL_ROUTING_STRATEGIES = (RandomRoutingStrategy,)
+ALL_ROUTING_STRATEGIES = dict(RANDOM=RandomRoutingStrategy)

+ 7 - 8
src/client/routing/sequence_manager.py

@@ -55,22 +55,22 @@ class RemoteSequenceManager(threading.Thread):
         p2p: Optional[P2P] = None,
         start: bool,
         update_period: float = 30,
-        routing_strategies: Collection[RoutingStrategyBase] = None,
+        routing_strategies: Dict[str, RoutingStrategyBase] = None,
     ):  # NB: if you add any more parameters, please make sure you pass them to sub-sequences in .__getitem__ below!
         super().__init__(daemon=True)
         self.dht, self.p2p = dht, (p2p if p2p is not None else dht.replicate_p2p())
         self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)  # to be updated in a background thread
 
         if routing_strategies is None:
-            routing_strategies = [Strategy(self.sequence_info) for Strategy in ALL_ROUTING_STRATEGIES]
-        self.routing_strategies: Dict[str, RoutingStrategyBase] = {s.name: s for s in routing_strategies}
-
+            routing_strategies = {key: Strategy(self.sequence_info) for key, Strategy in ALL_ROUTING_STRATEGIES.items()}
+        self.routing_strategies = routing_strategies
         self.last_update_time: DHTExpiration = -float("inf")
         self.update_period = update_period
 
-        self._rpc_info = None
-        self._lock_changes = threading.Lock()
+        self._rpc_info = None  # TODO move to RemoteSequenceInfo
+        self._lock_changes = threading.Lock()  # TODO move to RemoteSequenceInfo
         self.ready = threading.Event()  # whether or not you are ready to make_sequence
+        self.update_()  # TODO replace with background thread and await ready
 
         if start:
             self.run_in_background()
@@ -115,7 +115,6 @@ class RemoteSequenceManager(threading.Thread):
                 start=False,
             )  # NB: if you've added more parameters to __init__, please forward them in the instantiation above
             subseq.sequence_info = self.sequence_info[ix]
-            subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
             subseq._rpc_info = self._rpc_info
             subseq.last_update_time = self.last_update_time
             if self.is_alive():
@@ -125,7 +124,7 @@ class RemoteSequenceManager(threading.Thread):
     def update_(self):
         with self._lock_changes:
             self.sequence_info.update_(self.dht)
-            for name, strategy in self.routing_strategies:
+            for name, strategy in self.routing_strategies.items():
                 strategy.update_()
 
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None: