justheuristic il y a 3 ans
Parent
commit
fd7cb17f55
1 fichiers modifiés avec 4 ajouts et 1 suppressions
  1. 4 1
      hivemind/averaging/averager.py

+ 4 - 1
hivemind/averaging/averager.py

@@ -464,7 +464,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         step.set_result(
                         step.set_result(
                             await asyncio.wait_for(
                             await asyncio.wait_for(
                                 self._run_allreduce(
                                 self._run_allreduce(
-                                    group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
+                                    group_info,
+                                    tensor_infos=self.tensor_infos,
+                                    weight=step.weight,
+                                    **self.allreduce_kwargs,
                                 ),
                                 ),
                                 timeout=self._allreduce_timeout,
                                 timeout=self._allreduce_timeout,
                             )
                             )