@@ -158,6 +158,16 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
         return f"{self.__class__.__name__}({self.endpoint})"
 
     def run(self):
+        """
+        Run averager function in a background thread; this is needed to avoid a heisenbug with broken OMP on fork
+        Turns out, using a non-main thread creates a separate OMP pool that works even if the original pool is corrupted
+        Read more: https://github.com/pytorch/pytorch/issues/17199
+        """
+        thread = threading.Thread(target=self._run_internal, daemon=True)
+        thread.start()
+        thread.join()
+
+    def _run_internal(self):
         """ Serve DecentralizedAverager forever. This function will not return until the averager is shut down """
         loop = switch_to_uvloop()
         # initialize asyncio synchronization primitives in this event loop
@@ -240,41 +250,45 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragingServicer):
         start_time = get_dht_time()
         group_id = None
 
-        while not future.done():
-            try:
-                self._pending_group_assembled.clear()
-                data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
-                group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
-                if group_info is None:
-                    raise AllreduceException("Averaging step failed: could not find a group.")
-                group_id = group_info.group_id
-                allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
-                self._running_groups[group_id] = allreduce_runner
-                self._pending_group_assembled.set()
-                await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
-                await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
-
-                # averaging is finished, exit the loop
-                future.set_result(allreduce_runner.gathered)
-
-            except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
-                    asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
-                time_elapsed = get_dht_time() - start_time
-                if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                    logger.exception(f"Averager caught {repr(e)}")
-                    future.set_exception(e)
-                else:
-                    logger.warning(f"Averager caught {repr(e)}, retrying")
+        try:
+            while not future.done():
+                try:
+                    self._pending_group_assembled.clear()
+                    data_for_gather = self.serializer.dumps([weight, self._throughput, self.listen, gather_binary])
+                    group_info = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=data_for_gather)
+                    if group_info is None:
+                        raise AllreduceException("Averaging step failed: could not find a group.")
+                    group_id = group_info.group_id
+                    allreduce_runner = await self._make_allreduce_runner(group_info, **self.allreduce_kwargs)
+                    self._running_groups[group_id] = allreduce_runner
+                    self._pending_group_assembled.set()
+                    await asyncio.wait_for(allreduce_runner.run(), self._allreduce_timeout)
+                    await loop.run_in_executor(None, self.update_tensors, allreduce_runner)
+
+                    # averaging is finished, exit the loop
+                    future.set_result(allreduce_runner.gathered)
+
+                except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
+                        asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
+                    time_elapsed = get_dht_time() - start_time
+                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
+                        logger.exception(f"Averager caught {repr(e)}")
+                        future.set_exception(e)
+                    else:
+                        logger.warning(f"Averager caught {repr(e)}, retrying")
 
-            except BaseException as e:
+                finally:
+                    _ = self._running_groups.pop(group_id, None)
+                    self._pending_group_assembled.set()
+
+        except BaseException as e:
+            if not future.done():
                 future.set_exception(e)
-                raise
-            finally:
-                _ = self._running_groups.pop(group_id, None)
-                self._pending_group_assembled.set()
-                if not future.done():
-                    future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
                                                      " Please report this to hivemind issues."))
+            raise
+        finally:
+            if not future.done():
+                future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
                                                  " Please report this to hivemind issues."))
 
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """