@@ -121,6 +121,7 @@ class Optimizer(torch.optim.Optimizer):
:param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+ :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
:param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
:param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
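Note: the following is a minimal usage sketch of the new parameter, not part of the patch. The timeout arguments come from this diff; the remaining constructor arguments (dht, run_id, optimizer, target_batch_size, batch_size_per_step) are assumed to keep their usual hivemind.Optimizer names.

    import torch
    import hivemind

    model = torch.nn.Linear(16, 1)
    dht = hivemind.DHT(start=True)

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",
        optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
        target_batch_size=4096,
        batch_size_per_step=32,
        averaging_timeout=60.0,   # cancel an averaging step that hangs for this long
        allreduce_timeout=30.0,   # NEW: timeout for a single all-reduce attempt; defaults to averaging_timeout
    )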
@@ -173,6 +174,7 @@ class Optimizer(torch.optim.Optimizer):
scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
matchmaking_time: Optional[float] = 15.0,
averaging_timeout: Optional[float] = 60.0,
+ allreduce_timeout: Optional[float] = None,
load_state_timeout: float = 600.0,
reuse_grad_buffers: bool = False,
offload_optimizer: Optional[bool] = None,
@@ -197,6 +199,7 @@ class Optimizer(torch.optim.Optimizer):
client_mode = client_mode if client_mode is not None else dht.client_mode
delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+ allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
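Note: the fallback added above means that omitting allreduce_timeout keeps the previous behavior, where averaging_timeout was also used as the averagers' allreduce_timeout. A standalone sketch of that resolution logic (hypothetical helper name, not part of the patch):

    def resolve_allreduce_timeout(averaging_timeout, allreduce_timeout=None):
        # Mirrors the line added in this hunk: an unset allreduce_timeout
        # inherits the value of averaging_timeout.
        return allreduce_timeout if allreduce_timeout is not None else averaging_timeout

    assert resolve_allreduce_timeout(60.0) == 60.0        # default: same behavior as before this change
    assert resolve_allreduce_timeout(60.0, 30.0) == 30.0  # explicit: tighter per-attempt all-reduce timeout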
@@ -225,8 +228,8 @@ class Optimizer(torch.optim.Optimizer):
self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
- self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
- self.shutdown_timeout = shutdown_timeout
+ self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
+ self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
self.status_loglevel = logging.INFO if verbose else logging.DEBUG
self.scheduled_grads: Optional[StepControl] = None
@@ -271,7 +274,7 @@ class Optimizer(torch.optim.Optimizer):
dht=self.dht,
prefix=f"{self.run_id}_state_averager",
min_matchmaking_time=self.matchmaking_time,
- allreduce_timeout=self.averaging_timeout,
+ allreduce_timeout=self.allreduce_timeout,
shutdown_timeout=self.shutdown_timeout,
offload_optimizer=self.offload_optimizer,
custom_gradients=self.offload_optimizer,
@@ -289,7 +292,7 @@ class Optimizer(torch.optim.Optimizer):
prefix=f"{self.run_id}_grad_averager",
parameters=self.state_averager.main_parameters,
min_matchmaking_time=self.matchmaking_time,
- allreduce_timeout=self.averaging_timeout,
+ allreduce_timeout=self.allreduce_timeout,
shutdown_timeout=self.shutdown_timeout,
client_mode=self.client_mode,
auxiliary=self.auxiliary,