
Fix random freezes in averager.step, improve error handling (#254)

- fix a heisenbug where DecentralizedAverager would randomly hang on PyTorch ops (pytorch/pytorch#17199); see the sketch after this list
- tweak "sanity check failed" clause in DecentralizedAverager._step so that it is no longer triggered by sanctioned retries
- tweak Matchmaking.request_join_group to handle RPC errors instead of cancelling the entire matchmaking
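
Below is a minimal, self-contained sketch of the workaround pattern behind the first fix, using only the standard library; the class and method names are illustrative, not part of the hivemind API. After fork(), the parent's OpenMP thread pool can be left in a broken state, but a freshly started non-main thread gets its own pool, so the process body is delegated to a daemon thread (this mirrors the run/_run_internal change in the diff below).

import multiprocessing as mp
import threading


class BackgroundThreadProcess(mp.Process):
    def run(self):
        # A non-main thread creates a separate OMP pool that keeps working
        # even if the pool inherited from the forking parent is corrupted.
        thread = threading.Thread(target=self._run_internal, daemon=True)
        thread.start()
        thread.join()

    def _run_internal(self):
        # OMP-heavy work (e.g. torch tensor ops) is safe to run from here.
        print("worker body running in a background thread")


if __name__ == "__main__":
    proc = BackgroundThreadProcess()
    proc.start()
    proc.join()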

Co-authored-by: Michael Diskin <yhn1124@gmail.com>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic, 4 years ago
parent commit 94b9db0d37

+ 47 - 33
hivemind/client/averaging/__init__.py

@@ -158,6 +158,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         return f"{self.__class__.__name__}({self.endpoint})"
 
     def run(self):
+        """
+        Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
+        Turns out, using a non-main thread creates a separate OMP pool that works even if the original pool is corrupted
+        Read more: https://github.com/pytorch/pytorch/issues/17199
+        """
+        thread = threading.Thread(target=self._run_internal, daemon=True)
+        thread.start()
+        thread.join()
+
+    def _run_internal(self):
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
@@ -240,41 +250,45 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         start_time = get_dht_time()
         group_id = None
 
-        while not future.done():
-            try:
-                self._pending_group_assembled.clear()
-                data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
-                group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
-                if group_info is None:
-                    raise AllreduceException("Averaging step failed: could not find a group.")
-                group_id = group_info.group_id
-                allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                self._running_groups[group_id] = allreduce_runner
-                self._pending_group_assembled.set()
-                await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
-
-                # averaging is finished, exit the loop
-                future.set_result(allreduce_runner.gathered)
-
-            except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
-                    asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
-                time_elapsed = get_dht_time() - start_time
-                if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                    logger.exception(f"Averager caught {repr(e)}")
-                    future.set_exception(e)
-                else:
-                    logger.warning(f"Averager caught {repr(e)}, retrying")
+        try:
+            while not future.done():
+                try:
+                    self._pending_group_assembled.clear()
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
+                    if group_info is None:
+                        raise AllreduceException("Averaging step failed: could not find a group.")
+                    group_id = group_info.group_id
+                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
+                    self._running_groups[group_id] = allreduce_runner
+                    self._pending_group_assembled.set()
+                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
+                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
+
+                    # averaging is finished, exit the loop
+                    future.set_result(allreduce_runner.gathered)
+
+                except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
+                        asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
+                    time_elapsed = get_dht_time() - start_time
+                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
+                        logger.exception(f"Averager caught {repr(e)}")
+                        future.set_exception(e)
+                    else:
+                        logger.warning(f"Averager caught {repr(e)}, retrying")
 
-            except BaseException as e:
+                finally:
+                    _ = self._running_groups.pop(group_id, None)
+                    self._pending_group_assembled.set()
+
+        except BaseException as e:
+            if not future.done():
                 future.set_exception(e)
-                raise
-            finally:
-                _ = self._running_groups.pop(group_id, None)
-                self._pending_group_assembled.set()
-                if not future.done():
-                    future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
-                                                      " Please report this to hivemind issues."))
+            raise
+        finally:
+            if not future.done():
+                future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
+                                                  " Please report this to hivemind issues."))
 
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """

+ 5 - 0
hivemind/client/averaging/matchmaking.py

@@ -10,6 +10,7 @@ import concurrent.futures
 import asyncio
 
 import grpc
+import grpc._cython.cygrpc
 
 from hivemind.client.averaging.group_info import GroupInfo
 from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
@@ -199,6 +200,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             if call is not None:
                 call.cancel()
             return None
+        except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
+            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            return None
+
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None