Browse files

Fix random freezes in averager.step, improve error handling (#254)

- fix a heisenbug where DecentralizedAverager would randomly hang on pytorch ops (pytorch/pytorch#17199); see the sketch below the commit message
- tweak "sanity check failed" clause in DecentralizedAverager._step so that it is no longer triggered by sanctioned retries
- tweak Matchmaking.request_join_group to handle RPC errors instead of cancelling the entire matchmaking

Co-authored-by: Michael Diskin <yhn1124@gmail.com>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
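
The first bullet's workaround deserves a short illustration. Below is a minimal sketch of the pattern, assuming nothing beyond the standard library: the class and method names are ours, not hivemind's API. The forked process does all real work in a freshly started non-main thread, which gets its own OMP thread pool, so pytorch ops no longer hang even if the pool inherited across fork() is corrupted (pytorch/pytorch#17199).

    import multiprocessing as mp
    import threading

    class BackgroundThreadProcess(mp.Process):
        """Toy process that sidesteps the OMP-on-fork heisenbug by doing the
        actual work in a non-main thread of the child process."""

        def run(self):
            # a non-main thread gets a fresh OMP pool; join() keeps the process alive
            thread = threading.Thread(target=self._run_internal, daemon=True)
            thread.start()
            thread.join()

        def _run_internal(self):
            # the real averager starts its uvloop event loop here and serves forever
            print("running in a non-main thread of the child process")

    if __name__ == "__main__":
        proc = BackgroundThreadProcess()
        proc.start()
        proc.join()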
justheuristic committed 4 years ago
commit 94b9db0d37
2 changed files with 52 additions and 33 deletions
  1. hivemind/client/averaging/__init__.py (+47 -33)
  2. hivemind/client/averaging/matchmaking.py (+5 -0)

hivemind/client/averaging/__init__.py (+47 -33)

@@ -158,6 +158,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         return f"{self.__class__.__name__}({self.endpoint})"
 
     def run(self):
+        """
+        Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
+        Turns out, using a non-main thread creates a separate OMP pool that works even if the original pool is corrupted
+        Read more: https://github.com/pytorch/pytorch/issues/17199
+        """
+        thread = threading.Thread(target=self._run_internal, daemon=True)
+        thread.start()
+        thread.join()
+
+    def _run_internal(self):
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
         # initialize asyncio synchronization primitives in this event loop
@@ -240,41 +250,45 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         start_time = get_dht_time()
         group_id = None
 
-        while not future.done():
-            try:
-                self._pending_group_assembled.clear()
-                data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
-                group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
-                if group_info is None:
-                    raise AllreduceException("Averaging step failed: could not find a group.")
-                group_id = group_info.group_id
-                allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                self._running_groups[group_id] = allreduce_runner
-                self._pending_group_assembled.set()
-                await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
-
-                # averaging is finished, exit the loop
-                future.set_result(allreduce_runner.gathered)
-
-            except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
-                    asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
-                time_elapsed = get_dht_time() - start_time
-                if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                    logger.exception(f"Averager caught {repr(e)}")
-                    future.set_exception(e)
-                else:
-                    logger.warning(f"Averager caught {repr(e)}, retrying")
+        try:
+            while not future.done():
+                try:
+                    self._pending_group_assembled.clear()
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
+                    if group_info is None:
+                        raise AllreduceException("Averaging step failed: could not find a group.")
+                    group_id = group_info.group_id
+                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
+                    self._running_groups[group_id] = allreduce_runner
+                    self._pending_group_assembled.set()
+                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
+                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
+
+                    # averaging is finished, exit the loop
+                    future.set_result(allreduce_runner.gathered)
+
+                except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
+                        asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
+                    time_elapsed = get_dht_time() - start_time
+                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
+                        logger.exception(f"Averager caught {repr(e)}")
+                        future.set_exception(e)
+                    else:
+                        logger.warning(f"Averager caught {repr(e)}, retrying")
 
-            except BaseException as e:
+                finally:
+                    _ = self._running_groups.pop(group_id, None)
+                    self._pending_group_assembled.set()
+
+        except BaseException as e:
+            if not future.done():
                 future.set_exception(e)
-                raise
-            finally:
-                _ = self._running_groups.pop(group_id, None)
-                self._pending_group_assembled.set()
-                if not future.done():
-                    future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
-                                                      " Please report this to hivemind issues."))
+            raise
+        finally:
+            if not future.done():
+                future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
+                                                  " Please report this to hivemind issues."))
 
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """

hivemind/client/averaging/matchmaking.py (+5 -0)

@@ -10,6 +10,7 @@ import concurrent.futures
 import asyncio
 
 import grpc
+import grpc._cython.cygrpc
 
 from hivemind.client.averaging.group_info import GroupInfo
 from hivemind.client.averaging.key_manager import GroupKeyManager, GroupKey
@@ -199,6 +200,10 @@ class Matchmaking(averaging_pb2_grpc.DecentralizedAveragingServicer):
             if call is not None:
                 call.cancel()
             return None
+        except (grpc.RpcError, grpc.aio.AioRpcError, grpc._cython.cygrpc.InternalError, StopAsyncIteration) as e:
+            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            return None
+
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None
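
The new except clause above can be read as: an RPC failure while talking to one candidate leader should make request_join_group return None, so matchmaking simply moves on to the next candidate instead of being cancelled outright. A hedged sketch of that pattern follows; the stub, the request object and the rpc_join_group name are assumptions for illustration, not the exact hivemind call.

    import grpc
    import grpc.aio

    async def request_join(stub, request):
        try:
            call = stub.rpc_join_group(request)   # server-streaming RPC to the candidate leader
            return await call.read()              # the first message decides accept / reject
        except (grpc.RpcError, grpc.aio.AioRpcError, StopAsyncIteration) as e:
            # a dead or misbehaving leader must not abort the whole matchmaking round
            print(f"failed to request potential leader: {e}")
            return None                           # caller just tries the next candidate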