justheuristic 3 سال پیش
والد
کامیت
513c0e9da3
1فایلهای تغییر یافته به همراه5 افزوده شده و 5 حذف شده
  1. 5 5
      hivemind/averaging/averager.py

+ 5 - 5
hivemind/averaging/averager.py

@@ -455,7 +455,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                     group_info = await matchmaking_task
 
-                    async with self._register_allreduce_group(group_info):
+                    with self._register_allreduce_group(group_info):
                         if group_info is None:
                             raise AllreduceException("Averaging step failed: could not find a group")
 
@@ -505,16 +505,16 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 )
 
     @contextlib.contextmanager
-    def _register_allreduce_group(self, group_id: GroupID):
+    def _register_allreduce_group(self, group_info: GroupInfo):
         """registers a given all-reduce runner to listen for incoming connections"""
         try:
-            self._running_groups[group_id] = asyncio.Future()
+            self._running_groups[group_info.group_id] = asyncio.Future()
             self._pending_groups_registered.set()
             yield
         finally:
-            maybe_future = self._running_groups.pop(group_id, None)
+            maybe_future = self._running_groups.pop(group_info.group_id, None)
             if maybe_future and not maybe_future.done():
-                logger.warning(f"All-reduce group {group_id} did not finish.")
+                logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
             self._pending_groups_registered.set()
 
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData: