|
@@ -455,7 +455,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
|
|
|
group_info = await matchmaking_task
|
|
|
|
|
|
- async with self._register_allreduce_group(group_info):
|
|
|
+ with self._register_allreduce_group(group_info):
|
|
|
if group_info is None:
|
|
|
raise AllreduceException("Averaging step failed: could not find a group")
|
|
|
|
|
@@ -505,16 +505,16 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
)
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
- def _register_allreduce_group(self, group_id: GroupID):
|
|
|
+ def _register_allreduce_group(self, group_info: GroupInfo):
|
|
|
"""registers a given all-reduce runner to listen for incoming connections"""
|
|
|
try:
|
|
|
- self._running_groups[group_id] = asyncio.Future()
|
|
|
+ self._running_groups[group_info.group_id] = asyncio.Future()
|
|
|
self._pending_groups_registered.set()
|
|
|
yield
|
|
|
finally:
|
|
|
- maybe_future = self._running_groups.pop(group_id, None)
|
|
|
+ maybe_future = self._running_groups.pop(group_info.group_id, None)
|
|
|
if maybe_future and not maybe_future.done():
|
|
|
- logger.warning(f"All-reduce group {group_id} did not finish.")
|
|
|
+ logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
|
|
|
self._pending_groups_registered.set()
|
|
|
|
|
|
async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
|