Forráskód Böngészése

unironiously black

justheuristic 3 éve
szülő
commit
aa2b70eb34

+ 7 - 4
hivemind/averaging/averager.py

@@ -422,10 +422,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         await step.wait_for_trigger()
                     return group_info
                 except asyncio.CancelledError:
-                    await asyncio.wait({
-                        self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
-                        for peer_id in group_info.peer_ids if peer_id != self.peer_id
-                    })
+                    await asyncio.wait(
+                        {
+                            self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
+                            for peer_id in group_info.peer_ids
+                            if peer_id != self.peer_id
+                        }
+                    )
                     raise
 
             while not step.done():

+ 5 - 3
hivemind/optim/experimental/optimizer.py

@@ -396,9 +396,11 @@ class Optimizer(torch.optim.Optimizer):
             next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
             swarm_not_empty = self.tracker.global_progress.num_peers > 1
             should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
-            should_average_state = (swarm_not_empty and
-                                    next_epoch % self.average_state_every == 0 and
-                                    not self.state_averager.averaging_in_progress)
+            should_average_state = (
+                swarm_not_empty
+                and next_epoch % self.average_state_every == 0
+                and not self.state_averager.averaging_in_progress
+            )
 
             if should_average_state and self.scheduled_state is not None:
                 if self.scheduled_state.triggered or self.scheduled_state.done():