@@ -54,8 +54,6 @@ class Optimizer(torch.optim.Optimizer):
other peers have already made some progress and changed their learning rate accordingly.
TODO yozh, the doc below still needs an update
- #TODO forward timeout to state averager
- #TODO option to offload optimizer and DPU
:param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
:param dht: a running hivemind.DHT daemon connected to other peers
@@ -71,7 +69,7 @@ class Optimizer(torch.optim.Optimizer):
This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
:param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
:param kwargs: additional parameters forwarded to DecentralizedAverager
- :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
+ :note: If you are using hivemind.Optimizer with lr_scheduler, it is recommended to pass this scheduler
explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
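
A minimal construction sketch may help here. Only opt, dht, scheduler, and client_mode come from the docstring above; the stand-in model, the base optimizer, and the assumption that the constructor accepts exactly these keywords are illustrative, not a confirmed API for this revision:

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)  # a running DHT daemon connected to other peers
    model = torch.nn.Linear(16, 1)  # stand-in model for illustration
    base_opt = torch.optim.SGD(model.parameters(), lr=0.1)  # at scale, prefer LAMB/LARS
    # Pass the scheduler explicitly, as the note above recommends, so that
    # its state stays synchronized between peers:
    scheduler = torch.optim.lr_scheduler.LambdaLR(base_opt, lr_lambda=lambda step: 1.0)

    optimizer = hivemind.Optimizer(
        opt=base_opt,          # a standard pytorch optimizer
        dht=dht,
        scheduler=scheduler,
        client_mode=False,     # True runs without incoming connections (firewall mode)
    )
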
@@ -206,7 +204,6 @@ class Optimizer(torch.optim.Optimizer):
return
if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
- raise NotImplementedError()
logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
self.grad_averager.reset_accumulated_grads_()
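
The hunk above replaces an unconditional NotImplementedError with a recovery path: when the gradient scaler reports non-finite fp16 gradients, the peer logs the event, reports zero accumulated samples to the progress tracker, and resets its accumulated gradients. A self-contained sketch of the same overflow check, using plain torch in place of grad_scaler.are_grads_finite and optimizer.zero_grad in place of the averager reset (both substitutions are assumptions for illustration):

    import torch

    model = torch.nn.Linear(16, 1)
    base_opt = torch.optim.SGD(model.parameters(), lr=0.1)

    def grads_are_finite(opt: torch.optim.Optimizer) -> bool:
        # True iff every accumulated gradient contains only finite values
        return all(
            torch.isfinite(p.grad).all()
            for group in opt.param_groups
            for p in group["params"]
            if p.grad is not None
        )

    # Recovery path mirroring the diff: discard the bad accumulators instead of raising
    if not grads_are_finite(base_opt):
        print("Encountered incorrect value in fp16 grads, resetting local gradients")
        base_opt.zero_grad(set_to_none=True)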