@@ -201,7 +201,8 @@ class Optimizer(torch.optim.Optimizer):
         client_mode = client_mode if client_mode is not None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
-        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        if offload_optimizer is None:
+            offload_optimizer = params is not None and not use_local_updates
         allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
         next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
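
A minimal sketch of the new default behavior, for illustration only. The argument names mirror Optimizer.__init__ as shown in the hunk above; resolve_offload_optimizer is a hypothetical standalone helper, not part of the change. With this hunk applied, leaving offload_optimizer=None enables offloading only when explicit params are passed and use_local_updates is disabled.

    def resolve_offload_optimizer(offload_optimizer, params, use_local_updates):
        # Replicates the default resolution from the diff: an explicit value wins,
        # otherwise offloading requires explicit params and no purely local updates.
        if offload_optimizer is None:
            offload_optimizer = params is not None and not use_local_updates
        return offload_optimizer

    # Previously the default was simply (params is not None); now local updates disable it.
    assert resolve_offload_optimizer(None, params=[object()], use_local_updates=False) is True
    assert resolve_offload_optimizer(None, params=[object()], use_local_updates=True) is False
    assert resolve_offload_optimizer(True, params=None, use_local_updates=True) is True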