Fix random freezes in averager.step, improve error handling (#254)

- fix a heisenbug where DecentralizedAverager would randomly hang on PyTorch ops (see pytorch/pytorch#17199)
- adjust the "sanity check failed" clause in DecentralizedAverager._step so that it is no longer triggered by legitimate retries
- make Matchmaking.request_join_group handle RPC errors from individual peers instead of cancelling the entire matchmaking process

Co-authored-by: Michael Diskin <yhn1124@gmail.com>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic, 4 years ago
Parent commit: 94b9db0d37
2 changed files with 52 additions and 33 deletions
  1. hivemind/client/averaging/__init__.py (+47 -33)
  2. hivemind/client/averaging/matchmaking.py (+5 -0)

+ 47 - 33
hivemind/client/averaging/__init__.py

@@ -158,6 +158,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         return f"{self.__class__.__name__}({self.endpoint})"
         return f"{self.__class__.__name__}({self.endpoint})"
 
 
     def run(self):
+        """
+        Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
+        Turns out, using a non-main thread creates a separate OMP pool that works even if the original pool is corrupted
+        Read more: https://github.com/pytorch/pytorch/issues/17199
+        """
+        thread = threading.Thread(target=self._run_internal, daemon=True)
+        thread.start()
+        thread.join()
+
+    def _run_internal(self):
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
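
The run()/_run_internal() split above works around the fork-related hang described in the commit message: torch/OMP work on the main thread of a forked process can stall if the parent's OpenMP pool was already initialized, while a freshly spawned thread gets its own pool (pytorch/pytorch#17199). Below is a minimal, self-contained sketch of the same pattern; the class and function names are illustrative, not part of hivemind.

    import multiprocessing as mp
    import threading

    class BackgroundWorker(mp.Process):
        """Toy process that mirrors the run()/_run_internal() split above."""

        def run(self):
            # A thread spawned after the fork gets a fresh OpenMP pool,
            # sidestepping the corrupted-pool hang from pytorch/pytorch#17199.
            thread = threading.Thread(target=self._run_internal, daemon=True)
            thread.start()
            thread.join()

        def _run_internal(self):
            # heavy torch / OMP work would live here; a print stands in for it
            print("working on a non-main thread of the forked process")

    if __name__ == "__main__":
        worker = BackgroundWorker()
        worker.start()
        worker.join()
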
@@ -240,41 +250,45 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         start_time = get_dht_time()
         group_id = None
 
-        while not future.done():
-            try:
-                self._pending_group_assembled.clear()
-                data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
-                group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
-                if group_info is None:
-                    raise AllreduceException("Averaging step failed: could not find a group.")
-                group_id = group_info.group_id
-                allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                self._running_groups[group_id] = allreduce_runner
-                self._pending_group_assembled.set()
-                await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
-
-                # averaging is finished, exit the loop
-                future.set_result(allreduce_runner.gathered)
-
-            except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
-                    asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
-                time_elapsed = get_dht_time() - start_time
-                if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                    logger.exception(f"Averager caught {repr(e)}")
-                    future.set_exception(e)
-                else:
-                    logger.warning(f"Averager caught {repr(e)}, retrying")
+        try:
+            while not future.done():
+                try:
+                    self._pending_group_assembled.clear()
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
+                    if group_info is None:
+                        raise AllreduceException("Averaging step failed: could not find a group.")
+                    group_id = group_info.group_id
+                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
+                    self._running_groups[group_id] = allreduce_runner
+                    self._pending_group_assembled.set()
+                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
+                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
+
+                    # averaging is finished, exit the loop
+                    future.set_result(allreduce_runner.gathered)
+
+                except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
+                        asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
+                    time_elapsed = get_dht_time() - start_time
+                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
+                        logger.exception(f"Averager caught {repr(e)}")
+                        future.set_exception(e)
+                    else:
+                        logger.warning(f"Averager caught {repr(e)}, retrying")
 
-            except BaseException as e:
+                finally:
+                    _ = self._running_groups.pop(group_id, None)
+                    self._pending_group_assembled.set()
+
+        except BaseException as e:
+            if not future.done():
                 future.set_exception(e)
-                raise
-            finally:
-                _ = self._running_groups.pop(group_id, None)
-                self._pending_group_assembled.set()
-                if not future.done():
-                    future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
-                                                      " Please report this to hivemind issues."))
+            raise
+        finally:
+            if not future.done():
+                future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
+                                                  " Please report this to hivemind issues."))
 
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         """ Use a group description found by Matchmaking to form AllreduceRunner """

+ 5 - 0
hivemind/client/averaging/matchmaking.py

@@ -10,6 +10,7 @@ import concurrent.futures
 import asyncio
 
 import grpc
+import grpc._cython.cygrpc
 
 from hivemind.client.averaging.group_info import GroupInfo
 from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
@@ -199,6 +200,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             if call is not None:
                 call.cancel()
             return None
+        except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
+            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            return None
+
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None
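
The new except clause above turns a transport failure against one candidate leader into a logged None, so matchmaking simply moves on to the next candidate instead of cancelling the whole step. Here is the same pattern in isolation, with a hypothetical send_request callable and RpcError class standing in for the real grpc call and error types; the function names are simplified stand-ins, not hivemind's signatures.

    import asyncio
    import logging

    logger = logging.getLogger("matchmaking_sketch")

    class RpcError(Exception):
        """Stand-in for grpc.aio.AioRpcError / grpc._cython.cygrpc.InternalError."""

    async def request_join_group(leader, send_request):
        """Ask one potential leader to accept us; a failed RPC means 'try someone else'."""
        try:
            return await send_request(leader)         # the real code issues the join-group RPC here
        except (RpcError, StopAsyncIteration) as e:
            logger.error(f"failed to request potential leader {leader}: {e}")
            return None                               # caller interprets None as "keep looking"

    async def look_for_group(leaders, send_request):
        # one unreachable peer no longer aborts the search; we just try the next candidate
        for leader in leaders:
            group = await request_join_group(leader, send_request)
            if group is not None:
                return group
        return None

    async def _demo():
        async def send_request(leader):
            if leader == "dead-peer:1337":
                raise RpcError("connection reset")    # simulated transport failure
            return f"group led by {leader}"
        return await look_for_group(["dead-peer:1337", "good-peer:1337"], send_request)

    print(asyncio.run(_demo()))  # -> group led by good-peer:1337
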