justheuristic 3 years ago
Commit 7a0bc0246f
1 changed file with 7 additions and 0 deletions

hivemind/optim/experimental/state_averager.py (+7 −0)

@@ -152,6 +152,13 @@ class TrainingStateAverager(DecentralizedAverager):
         parameter_names = tuple(nested_flatten(parameter_names))
         assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
+        params_with_grad = sum(p.numel() for p in parameters if p.requires_grad)
+        params_no_grad = sum(p.numel() for p in parameters if not p.requires_grad)
+        if params_no_grad >= params_with_grad:
+            logger.warning("The majority of parameters have requires_grad=False, but they are still synchronized"
+                           " with peers. If these parameters are frozen (not updated), please do not feed them into"
+                           " the optimizer at all in order to avoid communication overhead.")
+
         return param_groups, parameters, parameter_names
 
     def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
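
The added check warns when frozen (requires_grad=False) parameters account for at least as many elements as trainable ones, since the averager still spends bandwidth synchronizing them with peers. Below is a minimal sketch of the suggested remedy; the frozen-backbone/trainable-head model and all names in it are illustrative assumptions, not part of this commit:

    import torch
    import torch.nn as nn

    # Hypothetical model: a frozen backbone layer plus a small trainable head.
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
    for p in model[0].parameters():
        p.requires_grad_(False)  # freeze the backbone layer

    # Feed only trainable parameters to the optimizer, so the frozen ones
    # never reach TrainingStateAverager and incur no communication overhead.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.Adam(trainable, lr=1e-3)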