Michael Diskin 4 年之前
父节点
当前提交
1d1496329d
共有 2 个文件被更改,包括 5 次插入2 次删除
  1. 4 0
      hivemind/averaging/averager.py
  2. 1 2
      hivemind/optim/collaborative.py

+ 4 - 0
hivemind/averaging/averager.py

@@ -228,21 +228,25 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         client_mode=self.client_mode,
                         **self.matchmaking_kwargs,
                     )
+                    logger.debug(f"The 1")
                     if not self.client_mode:
                         asyncio.create_task(self._declare_for_download_periodically())
 
                     self._pending_group_assembled = asyncio.Event()
                     self._pending_group_assembled.set()
+                    logger.debug(f"The 2")
                 except Exception as e:
                     # Loglevel is DEBUG since normally the exception is propagated to the caller
                     logger.debug(e, exc_info=True)
                     self._ready.set_exception(e)
                     return
                 self._ready.set_result(None)
+                logger.debug(f"The 3")
 
                 while True:
                     try:
                         method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
+                        logger.debug(f"The 4")
                     except (OSError, ConnectionError) as e:
                         logger.exception(e)
                         await asyncio.sleep(self._matchmaking.request_timeout)

+ 1 - 2
hivemind/optim/collaborative.py

@@ -240,8 +240,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             current_step, group_info = self.averager.local_step, None
 
             if self.collaboration_state.num_peers > 1:
-                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
-                weight = self.local_samples_accumulated / mean_samples_per_worker
+                weight = self.local_samples_accumulated / self.target_batch_size
                 try:
                     group_info = self.averager.step(weight=weight, timeout=self.averaging_timeout, **kwargs)
                     if group_info: