
Log collaboration step to Wandb, store metrics only if peer is synchronized (#267)

These are small practical changes moved from https://github.com/mryab/collaborative-training

Co-authored-by: Michael Diskin <yhn1124@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Aleksandr Borzunov 4 years ago
parent
commit
020c068344
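
Condensed, the two behavioral changes look roughly like the sketch below. The helper function names are illustrative only; the remaining identifiers come from the diffs that follow.

```python
import hivemind
import wandb


def report_collaboration_metrics(current_loss, alive_peers, num_samples, sum_perf, latest_step):
    # run_first_peer.py: the latest collaboration step is now logged to Wandb
    # alongside the other aggregated metrics
    wandb.log({
        "loss": current_loss,
        "alive peers": alive_peers,
        "samples": num_samples,
        "performance": sum_perf,
        "step": latest_step,
    })


def store_local_metrics(collaborative_optimizer, dht, local_public_key,
                        statistics, statistics_expiration):
    # run_trainer.py: local metrics are published to the DHT only once this peer
    # has synchronized its state with the rest of the collaboration
    if collaborative_optimizer.is_synchronized:
        dht.store(key=collaborative_optimizer.prefix + "_metrics",
                  subkey=local_public_key, value=statistics.dict(),
                  expiration_time=hivemind.get_dht_time() + statistics_expiration,
                  return_future=True)
```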

+ 3 - 2
examples/albert/run_first_peer.py

@@ -177,13 +177,14 @@ if __name__ == '__main__':
                     num_samples += item.samples_accumulated
                     sum_mini_steps += item.mini_steps
                 current_loss = sum_loss / sum_mini_steps
-                
+
                 if coordinator_args.wandb_project is not None:
                     wandb.log({
                         "loss": current_loss,
                         "alive peers": alive_peers,
                         "samples": num_samples,
-                        "performance": sum_perf
+                        "performance": sum_perf,
+                        "step": latest_step
                     })
                 if checkpoint_handler.is_time_to_save_state(current_step):
                     checkpoint_handler.save_state(current_step)

+ 7 - 6
examples/albert/run_trainer.py

@@ -112,7 +112,7 @@ class CollaborativeCallback(transformers.TrainerCallback):
 
     def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
                        control: transformers.TrainerControl, **kwargs):
-        logger.warning('Loading state from peers')
+        logger.info('Loading state from peers')
         self.collaborative_optimizer.load_state_from_peers()
 
     def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
@@ -139,14 +139,15 @@ class CollaborativeCallback(transformers.TrainerCallback):
                 logger.info(f"Step {self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 if self.steps:
-                    logger.info(f"Loss of your model: {self.loss/self.steps}")
+                    logger.info(f"Local loss: {self.loss / self.steps}")
 
                 self.loss = 0
                 self.steps = 0
-                self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
-                               subkey=self.local_public_key, value=statistics.dict(),
-                               expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
-                               return_future=True)
+                if self.collaborative_optimizer.is_synchronized:
+                    self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
+                                   subkey=self.local_public_key, value=statistics.dict(),
+                                   expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                                   return_future=True)
 
         self.samples = self.collaborative_optimizer.local_samples_accumulated
 

+ 2 - 2
hivemind/client/averaging/allreduce.py

@@ -54,7 +54,7 @@ class AllReduceProtocol:
         self.averaged_part: asyncio.Future[torch.Tensor] = asyncio.Future()  # will be set to [accumulator / group size]
         self.averaged_tensor_parts: Dict[Endpoint, torch.Tensor] = {}  # averaged chunks from all peers will be put here
         self.future: asyncio.Future[Sequence[torch.Tensor]] = asyncio.Future()  # final result or exception
-        
+
         self.num_senders = len([mode for mode in modes if mode != AveragingMode.AUX])
 
         if self.num_senders == 0:
@@ -258,7 +258,7 @@ class AllReduceRunner(AllReduceProtocol, averaging_pb2_grpc.DecentralizedAveragi
             yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
 
 
-def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int]) -> Tuple[torch.Tensor, ...]:
+def split_into_parts(tensors: Sequence[torch.Tensor], part_sizes: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]:
     """ combines averaged_tensors into one tensor and splits them into equal chunks of size group_size """
     flat_tensor = torch.cat(tuple(map(torch.Tensor.flatten, tensors)))
     return torch.split_with_sizes(flat_tensor, part_sizes, dim=0)
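
The annotation fix reflects that `part_sizes` is a variable-length tuple: `Tuple[int]` would mean exactly one element, while `Tuple[int, ...]` allows one size per peer. A minimal usage sketch, assuming the module layout from this commit and illustrative tensor shapes:

```python
import torch
from hivemind.client.averaging.allreduce import split_into_parts  # path as in this diff

# two parameter tensors with 4 + 3 = 7 elements in total
tensors = [torch.ones(2, 2), torch.zeros(3)]
part_sizes = (4, 3)  # one chunk size per peer, hence Tuple[int, ...]

chunks = split_into_parts(tensors, part_sizes)
assert [chunk.numel() for chunk in chunks] == [4, 3]
```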