@@ -35,7 +35,7 @@ class Optimizer(torch.optim.Optimizer):
 
     >>> model = transformers.AutoModel("albert-xxlarge-v2")
     >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
-    >>> opt = hivemind.Optimizer(model.parameters(), optim_cls=torch.optim.Adam, prefix="run_42",
+    >>> opt = hivemind.Optimizer(model.parameters(), optim_cls=torch.optim.Adam, run_id="run_42",
     >>>                          target_batch_size=4096, batch_size_per_step=4)
     >>> while True:
     >>>     loss = compute_loss_on_batch(model, batch_size=4)
@@ -54,7 +54,7 @@ class Optimizer(torch.optim.Optimizer):
     other peers have already made some progress and changed their learning rate accordingly.
 
     :param dht: a running hivemind.DHT instance connected to other peers
-    :param prefix: a unique name of this experiment, used as a common prefix for all DHT keys
+    :param run_id: a unique name of this experiment, used as a common prefix for all DHT keys
     :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
     :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
     :param optimizer: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
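A quick arithmetic sketch of how the two batch-size parameters documented above interact; the peer count and variable names are illustrative and not part of the diff:

    # Using the values from the docstring example above.
    target_batch_size = 4096    # samples all peers must collectively accumulate per optimizer step
    batch_size_per_step = 4     # samples a single peer contributes per local .step() call
    num_peers = 16              # illustrative; the real number fluctuates as peers join and leave
    local_calls_per_global_step = target_batch_size / (batch_size_per_step * num_peers)
    print(local_calls_per_global_step)  # -> 64.0 local .step() calls per peer between averaging rounds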
@@ -88,7 +88,7 @@ class Optimizer(torch.optim.Optimizer):
         self,
         *,
         dht: DHT,
-        prefix: str,
+        run_id: str,
         target_batch_size: int,
         batch_size_per_step: Optional[int] = None,
         optimizer: Union[TorchOptimizer, OptimizerFactory],
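For readers skimming the rename, a minimal construction sketch against the keyword-only signature shown in this hunk. Everything beyond dht, run_id, target_batch_size, batch_size_per_step, and optimizer is an assumption for illustration: the toy model, the empty initial_peers list (which starts a fresh DHT), and passing a ready-made Adam instance rather than a factory.

    import torch
    import hivemind

    # Stand-ins for a real model and a real peer list (assumptions, not from the diff).
    model = torch.nn.Linear(512, 512)
    dht = hivemind.DHT(initial_peers=[], start=True)

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="run_42",                                  # was prefix="run_42" before this change
        target_batch_size=4096,                           # global samples per collective optimizer step
        batch_size_per_step=4,                            # local samples accumulated per .step() call
        optimizer=torch.optim.Adam(model.parameters()),   # a TorchOptimizer instance, per the annotation above
    )

From here the loop in the docstring hunk applies: compute the loss on a local batch and drive opt like a regular PyTorch optimizer (backward, step, zero_grad).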
@@ -114,7 +114,7 @@ class Optimizer(torch.optim.Optimizer):
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
 
-        self.dht, self.prefix, self.client_mode, self.auxiliary = dht, prefix, client_mode, auxiliary
+        self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
         self.matchmaking_time, self.average_state_every = matchmaking_time, average_state_every
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
@@ -141,7 +141,7 @@ class Optimizer(torch.optim.Optimizer):
     def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
         return TrainingStateAverager(
             dht=self.dht,
-            prefix=f"{self.prefix}_state_averager",
+            prefix=f"{self.run_id}_state_averager",
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
             status_loglevel=self.status_loglevel,
@@ -157,7 +157,7 @@ class Optimizer(torch.optim.Optimizer):
         assert hasattr(self, "state_averager"), "must initialize state averager first"
         grad_averager = GradientAverager(
             dht=self.dht,
-            prefix=f"{self.prefix}_grad_averager",
+            prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
@@ -177,7 +177,7 @@ class Optimizer(torch.optim.Optimizer):
     def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
         return ProgressTracker(
             dht=self.dht,
-            prefix=self.prefix,
+            prefix=self.run_id,
             target_batch_size=target_batch_size,
             client_mode=self.client_mode,
             status_loglevel=self.status_loglevel,
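The three hunks above make the fan-out explicit: the renamed run_id still feeds each sub-component's prefix= field. A tiny sketch of the derived DHT key prefixes for the docstring's run id:

    # Mirrors the f-strings in the hunks above.
    run_id = "run_42"
    state_averager_prefix = f"{run_id}_state_averager"   # -> "run_42_state_averager"
    grad_averager_prefix = f"{run_id}_grad_averager"     # -> "run_42_grad_averager"
    progress_tracker_prefix = run_id                     # -> "run_42"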
@@ -444,7 +444,7 @@ class Optimizer(torch.optim.Optimizer):
         )
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(prefix={self.prefix}, epoch={self.local_epoch})"
+        return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"
 
     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
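One detail from the last hunk that downstream log parsers may care about: the attribute is renamed to run_id, but the repr string keeps the old prefix= label. A standalone sketch of the resulting string, with an illustrative class name and epoch value:

    # Mirrors the f-string in the __repr__ hunk above; the values are illustrative.
    class_name, run_id, local_epoch = "Optimizer", "run_42", 0
    print(f"{class_name}(prefix={run_id}, epoch={local_epoch})")  # -> Optimizer(prefix=run_42, epoch=0)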