@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import os
 from typing import Optional, Union
 
 import torch
@@ -34,7 +35,7 @@ class Optimizer(torch.optim.Optimizer):
     >>> model = transformers.AutoModel("albert-xxlarge-v2")
     >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
     >>> opt = hivemind.Optimizer(model.parameters(), optim_cls=torch.optim.Adam, prefix="run_42",
-    >>>                          target_batch_size=4096, average_gradients=True, batch_size_per_step=4)
+    >>>                          target_batch_size=4096, batch_size_per_step=4)
     >>> while True:
     >>>     loss = compute_loss_on_batch(model, batch_size=4)
     >>>     opt.zero_grad()
@@ -118,6 +119,7 @@ class Optimizer(torch.optim.Optimizer):
         self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
         self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
         self._schema_hash = self._compute_schema_hash()
+        self._parent_pid = os.getpid()
 
         self._step_supports_amp_scaling = self.grad_averager.reuse_grad_buffers
         # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
@@ -345,7 +347,15 @@ class Optimizer(torch.optim.Optimizer):
 
     @property
     def param_groups(self) -> ParamGroups:
-        return self.state_averager.optimizer.param_groups
+        next_index = 0
+        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
+        for param_group in param_groups:
+            num_params = len(param_group["params"])
+            main_params_for_group = self.state_averager.main_parameters[next_index: next_index + num_params]
+            param_group["params"] = main_params_for_group
+            next_index += num_params
+        assert next_index == len(self.state_averager.main_parameters)
+        return param_groups
 
     def add_param_group(self, param_group: dict) -> None:
         raise ValueError(
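Note on the param_groups change above: the property now hands out copies of the inner optimizer's param groups, with each group's "params" re-pointed at the corresponding slice of state_averager.main_parameters, walked in group order. Below is a minimal, self-contained sketch of that re-binding pattern; the names rebuild_param_groups, inner_groups, and main_parameters are illustrative stand-ins, not part of the hivemind API.

import torch

# User-visible ("main") parameters, in the same order the groups cover them.
main_parameters = [torch.zeros(3), torch.zeros(2), torch.zeros(1)]
# Groups as an inner optimizer might hold them: same order, possibly different tensor objects.
inner_groups = [
    {"params": [torch.zeros(3), torch.zeros(2)], "lr": 1e-3},
    {"params": [torch.zeros(1)], "lr": 1e-4},
]

def rebuild_param_groups(inner_groups, main_parameters):
    """Return copies of inner_groups whose 'params' point at slices of main_parameters."""
    rebuilt, next_index = [], 0
    for group in inner_groups:
        group = dict(group)  # shallow copy: keep lr, betas, etc. without mutating the inner optimizer
        num_params = len(group["params"])
        group["params"] = main_parameters[next_index: next_index + num_params]
        next_index += num_params
        rebuilt.append(group)
    assert next_index == len(main_parameters), "groups must cover every main parameter exactly once"
    return tuple(rebuilt)

groups = rebuild_param_groups(inner_groups, main_parameters)
assert groups[0]["params"][0] is main_parameters[0]  # re-bound to the main tensor, not the inner copy

The shallow dict copy keeps hyperparameters such as lr visible on the returned groups while only the "params" entry is swapped out.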
@@ -366,5 +376,5 @@ class Optimizer(torch.optim.Optimizer):
         logger.debug(f"{self.__class__.__name__} is shut down.")
 
     def __del__(self):
-        if self.is_alive():  # TODO check os.getpid!!!
+        if self._parent_pid == os.getpid() and self.is_alive():
             self.shutdown()
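Note on the __del__ change above: together with the _parent_pid = os.getpid() line added in __init__, this is the usual parent-PID guard, so cleanup only runs in the process that created the optimizer and a forked child that inherits the object (for example a DataLoader worker) does not tear down shared background machinery. A minimal sketch of the pattern, using a hypothetical BackgroundResource class rather than hivemind's Optimizer:

import os

class BackgroundResource:
    """Toy stand-in for an object that owns background workers (hypothetical, for illustration)."""

    def __init__(self):
        self._parent_pid = os.getpid()  # remember which process constructed this object
        self._running = True

    def is_alive(self) -> bool:
        return self._running

    def shutdown(self) -> None:
        self._running = False  # real code would stop its threads / child processes here

    def __del__(self):
        # After fork(), the child inherits this object but has a different pid,
        # so its garbage collector must not shut down resources owned by the parent.
        if self._parent_pid == os.getpid() and self.is_alive():
            self.shutdown()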