|
@@ -455,9 +455,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
|
|
|
group_info = await matchmaking_task
|
|
|
|
|
|
+ if group_info is None:
|
|
|
+ raise AllreduceException("Averaging step failed: could not find a group")
|
|
|
+
|
|
|
with self._register_allreduce_group(group_info):
|
|
|
- if group_info is None:
|
|
|
- raise AllreduceException("Averaging step failed: could not find a group")
|
|
|
|
|
|
step.stage = AveragingStage.RUNNING_ALLREDUCE
|
|
|
|