|
@@ -464,7 +464,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
step.set_result(
|
|
step.set_result(
|
|
await asyncio.wait_for(
|
|
await asyncio.wait_for(
|
|
self._run_allreduce(
|
|
self._run_allreduce(
|
|
- group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
|
|
|
|
|
|
+ group_info,
|
|
|
|
+ tensor_infos=self.tensor_infos,
|
|
|
|
+ weight=step.weight,
|
|
|
|
+ **self.allreduce_kwargs,
|
|
),
|
|
),
|
|
timeout=self._allreduce_timeout,
|
|
timeout=self._allreduce_timeout,
|
|
)
|
|
)
|