justheuristic hace 3 años
padre
commit
fd7cb17f55
Se han modificado 1 ficheros con 4 adiciones y 1 borrados
  1. 4 1
      hivemind/averaging/averager.py

+ 4 - 1
hivemind/averaging/averager.py

@@ -464,7 +464,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         step.set_result(
                             await asyncio.wait_for(
                                 self._run_allreduce(
-                                    group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
+                                    group_info,
+                                    tensor_infos=self.tensor_infos,
+                                    weight=step.weight,
+                                    **self.allreduce_kwargs,
                                 ),
                                 timeout=self._allreduce_timeout,
                             )