@@ -152,6 +152,13 @@ class TrainingStateAverager(DecentralizedAverager):
         parameter_names = tuple(nested_flatten(parameter_names))
         assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
+        params_with_grad = sum(p.numel() for p in parameters if p.requires_grad)
+        params_no_grad = sum(p.numel() for p in parameters if not p.requires_grad)
+        if params_no_grad >= params_with_grad:
+            logging.warning("The majority of parameters have requires_grad=False, but they are still synchronized"
+                            " with peers. If these parameters are frozen (not updated), please do not feed them into "
+                            "the optimizer at all in order to avoid communication overhead.")
+
         return param_groups, parameters, parameter_names

     def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
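
The added check warns when most parameter elements have `requires_grad=False`, since such parameters are still averaged with peers even though they never change. Below is a minimal sketch (not part of the diff, model and optimizer are illustrative) of what the warning recommends on the user side: pass only trainable parameters to the optimizer so frozen ones never reach the averager.

```python
import torch

# Illustrative model with a frozen first layer (assumed for this example).
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
for p in model[0].parameters():
    p.requires_grad = False  # freeze the first layer

# Feed only parameters with requires_grad=True into the optimizer;
# frozen parameters are then never synchronized and incur no communication overhead.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)
```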