
multigroup

justheuristic · 3 years ago
commit 2cd92dcbe7
2 files changed with 57 additions and 49 deletions

  1. +45 -40 hivemind/averaging/averager.py
  2. +12 -9 tests/test_allreduce_fault_tolerance.py

+45 -40 hivemind/averaging/averager.py

@@ -109,7 +109,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     """
 
     _matchmaking: Matchmaking
-    _pending_group_assembled: asyncio.Event
+    _pending_groups_registered: asyncio.Event
     _state_updated: asyncio.Event
     _p2p: P2P
     serializer = MSGPackSerializer
@@ -207,7 +207,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
-        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
+        self._running_groups: Dict[GroupID, asyncio.Future[AllReduceRunner]] = {}
 
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
 
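Note: the hunk above changes _running_groups from holding AllReduceRunner objects to holding asyncio.Future[AllReduceRunner], so a group can be registered before its runner is actually constructed. A minimal stand-alone sketch of that pattern in plain asyncio (the string result and the GroupID alias are placeholders, not hivemind code):

    import asyncio
    from typing import Dict

    GroupID = bytes  # illustrative alias; hivemind defines its own GroupID type

    async def main() -> None:
        # the dict stores futures, so a group can be announced before its runner exists
        running_groups: Dict[GroupID, "asyncio.Future[str]"] = {}

        group_id: GroupID = b"demo-group"
        running_groups[group_id] = asyncio.Future()   # registration happens first

        async def late_consumer() -> None:
            # e.g. a request handler that was asked about this group very early
            runner = await running_groups[group_id]
            print("consumer got:", runner)

        consumer = asyncio.create_task(late_consumer())
        await asyncio.sleep(0.1)                      # the runner is built some time later
        running_groups[group_id].set_result("AllReduceRunner placeholder")
        await consumer

    asyncio.run(main())

Consumers that look a group up too early now block on the future instead of hitting a missing or half-initialized runner.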
@@ -309,8 +309,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.create_task(self._declare_for_download_periodically())
 
                 self._state_updated = asyncio.Event()
-                self._pending_group_assembled = asyncio.Event()
-                self._pending_group_assembled.set()
+                self._pending_groups_registered = asyncio.Event()
+                self._pending_groups_registered.set()
             except Exception as e:
                 # Loglevel is DEBUG since normally the exception is propagated to the caller
                 logger.debug(e, exc_info=True)
@@ -441,7 +441,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             while not step.done():
                 try:
-                    self._pending_group_assembled.clear()
+                    self._pending_groups_registered.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
                     matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())
@@ -455,20 +455,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                     group_info = await matchmaking_task
 
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group")

+                    with self._register_allreduce_group(group_info.group_id):
-                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+                        step.stage = AveragingStage.RUNNING_ALLREDUCE
 
-                    step.set_result(
-                        await asyncio.wait_for(
-                            self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
-                            ),
-                            timeout=self._allreduce_timeout,
+                        step.set_result(
+                            await asyncio.wait_for(
+                                self._run_allreduce(
+                                    group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
+                                ),
+                                timeout=self._allreduce_timeout,
+                            )
                         )
-                    )
-                    # averaging is finished, loop will now exit
+                        # averaging is finished, loop will now exit
 
                 except (
                     AllreduceException,
@@ -503,6 +504,19 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )
 
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_id: GroupID):
+        """registers a given all-reduce runner to listen for incoming connections"""
+        try:
+            self._running_groups[group_id] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            maybe_future = self._running_groups.pop(group_id, None)
+            if maybe_future is not None and not maybe_future.done():
+                logger.warning(f"All-reduce group {group_id} did not finish.")
+            self._pending_groups_registered.set()
+
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
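Note: the new _register_allreduce_group context manager publishes an unresolved future on entry, lets the all-reduce code resolve it once the runner is built, and pops it on exit, warning if it was never resolved. A self-contained sketch of that lifecycle (placeholder names and print calls instead of hivemind's logger; not the actual class method):

    import asyncio
    import contextlib
    from typing import Dict

    async def main() -> None:
        running_groups: Dict[bytes, "asyncio.Future[str]"] = {}
        pending_groups_registered = asyncio.Event()

        @contextlib.contextmanager
        def register_allreduce_group(group_id: bytes):
            # same register/resolve/cleanup order as _register_allreduce_group above
            try:
                running_groups[group_id] = asyncio.Future()
                pending_groups_registered.set()
                yield
            finally:
                maybe_future = running_groups.pop(group_id, None)
                if maybe_future is not None and not maybe_future.done():
                    print(f"group {group_id!r} was registered but never resolved")
                pending_groups_registered.set()

        with register_allreduce_group(b"group-1"):
            # once the runner object exists, the registering side resolves the future
            running_groups[b"group-1"].set_result("AllReduceRunner placeholder")
            print("resolved to:", await running_groups[b"group-1"])
        # on exit the entry is removed; nothing is printed because the future was resolved

    asyncio.run(main())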
@@ -531,18 +545,19 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     modes=modes,
                     **kwargs,
                 )
+                self._running_groups[group_info.group_id].set_result(allreduce)
+                # ^--- maybe this can be extracted into a method that checks whether the _register_allreduce_group context is active
 
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        iter_results = allreduce.run()
-                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
+                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                    iter_results = allreduce.run()
+                    async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                        # all-reduce is performed asynchronously while iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                    self._state_updated.set()
 
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
+                else:
+                    async for _ in allreduce:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
 
                 return allreduce.gathered
         except BaseException as e:
@@ -550,17 +565,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
-    @contextlib.contextmanager
-    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
-        """registers a given all-reduce runner to listen for incoming connections"""
-        try:
-            self._running_groups[group_id] = allreduce
-            self._pending_group_assembled.set()
-            yield
-        finally:
-            self._running_groups.pop(group_id, None)
-            self._pending_group_assembled.set()
-
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
         """
@@ -586,13 +590,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
             # but his response with group_id was delayed and other peers got to us first
-            await self._pending_group_assembled.wait()
+            await self._pending_groups_registered.wait()
 
-        group = self._running_groups.get(request.group_id)
-        if group is None:
+        future = self._running_groups.get(request.group_id)
+        if future is None:
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return
 
+        group = await future
         async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message
 
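Note: the hunk above is the receiving side of the same handshake: a peer whose rpc_aggregate_part request arrives before the local group is registered waits on _pending_groups_registered, then awaits the group's future until the runner exists. A stand-alone sketch of that ordering in plain asyncio (string results stand in for AllReduceRunner and the BAD_GROUP_ID response):

    import asyncio
    from typing import Dict

    async def main() -> None:
        running_groups: Dict[bytes, "asyncio.Future[str]"] = {}
        pending_groups_registered = asyncio.Event()

        async def rpc_aggregate_part(group_id: bytes) -> str:
            # the request may arrive before this peer has registered the group locally
            if group_id not in running_groups:
                await pending_groups_registered.wait()
            future = running_groups.get(group_id)
            if future is None:
                return "BAD_GROUP_ID"        # still unknown after waiting
            runner = await future            # wait until the runner is constructed
            return f"aggregating with {runner}"

        early_request = asyncio.create_task(rpc_aggregate_part(b"g"))
        await asyncio.sleep(0.05)                       # meanwhile, matchmaking finishes...
        running_groups[b"g"] = asyncio.Future()         # ...the group gets registered...
        pending_groups_registered.set()
        running_groups[b"g"].set_result("AllReduceRunner placeholder")  # ...and the runner is built
        print(await early_request)

    asyncio.run(main())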

+12 -9 tests/test_allreduce_fault_tolerance.py

@@ -66,16 +66,19 @@ class FaultyAverager(hivemind.DecentralizedAverager):
                     **kwargs,
                 )
 
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
+                self._running_groups[group_info.group_id].set_result(allreduce)
+                # ^--- maybe this can be extracted into a method that checks whether the _register_allreduce_group context is active
 
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
+                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                    iter_results = allreduce.run()
+                    async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                        # all-reduce is performed asynchronously while iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                    self._state_updated.set()
+
+                else:
+                    async for _ in allreduce:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
 
                 return allreduce.gathered
         except BaseException as e:
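Note: FaultyAverager now mirrors the production split between regular and aux peers: regular peers apply averaged updates in place while iterating over the runner, aux peers only drive the iteration and must never receive tensors. A toy sketch of that consumption loop (plain floats stand in for torch tensors; fake_allreduce is a stand-in for AllReduceRunner.run()):

    import asyncio
    from typing import AsyncIterator, List

    async def fake_allreduce(updates: List[float]) -> AsyncIterator[float]:
        # stand-in for AllReduceRunner.run(): yields one averaged update per local tensor
        for update in updates:
            await asyncio.sleep(0)       # pretend a network round-trip happens here
            yield update

    async def main() -> None:
        averaging_alpha = 0.5
        local_tensors = [1.0, 2.0, 3.0]  # toy stand-ins for torch tensors
        is_aux = False

        if not is_aux:
            index = 0
            async for update in fake_allreduce([0.5, -1.0, 0.25]):
                # same idea as tensor.add_(update, alpha=self._averaging_alpha)
                local_tensors[index] += averaging_alpha * update
                index += 1
            print(local_tensors)         # [1.25, 1.5, 3.125]
        else:
            async for _ in fake_allreduce([]):
                raise ValueError("aux peers should not receive averaged tensors")

    asyncio.run(main())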