justheuristic 3 years ago
parent
commit
b6a5757ef3
1 changed file with 13 additions and 3 deletions

hivemind/optim/experimental/optimizer.py  +13 −3

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import os
 from typing import Optional, Union
 
 import torch
@@ -34,7 +35,7 @@ class Optimizer(torch.optim.Optimizer):
     >>> model = transformers.AutoModel("albert-xxlarge-v2")
     >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
     >>> opt = hivemind.Optimizer(model.parameters(), optim_cls=torch.optim.Adam, prefix="run_42",
-    >>>                          target_batch_size=4096, average_gradients=True, batch_size_per_step=4)
+    >>>                          target_batch_size=4096, batch_size_per_step=4)
     >>> while True:
     >>>     loss = compute_loss_on_batch(model, batch_size=4)
     >>>     opt.zero_grad()
@@ -118,6 +119,7 @@ class Optimizer(torch.optim.Optimizer):
         self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
         self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
         self._schema_hash = self._compute_schema_hash()
+        self._parent_pid = os.getpid()
 
         self._step_supports_amp_scaling = self.grad_averager.reuse_grad_buffers
         # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
@@ -345,7 +347,15 @@ class Optimizer(torch.optim.Optimizer):
 
     @property
     def param_groups(self) -> ParamGroups:
-        return self.state_averager.optimizer.param_groups
+        next_index = 0
+        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
+        for param_group in param_groups:
+            num_params = len(param_group["params"])
+            main_params_for_group = self.state_averager.main_parameters[next_index: next_index + num_params]
+            param_group["params"] = main_params_for_group
+            next_index += num_params
+        assert next_index == len(self.state_averager.main_parameters)
+        return param_groups
 
     def add_param_group(self, param_group: dict) -> None:
         raise ValueError(
@@ -366,5 +376,5 @@ class Optimizer(torch.optim.Optimizer):
         logger.debug(f"{self.__class__.__name__} is shut down.")
 
     def __del__(self):
-        if self.is_alive():  # TODO check os.getpid!!!
+        if self._parent_pid == os.getpid() and self.is_alive():
             self.shutdown()
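
Note on the param_groups change: the property now re-maps the wrapped optimizer's parameter groups onto the user-facing parameters kept by the state averager, so code that inspects opt.param_groups (for example, a learning-rate scheduler or a parameter counter) sees the same tensors the user originally passed in. Below is a minimal sketch of that idea with a hypothetical WrappedOptimizer class, not the actual hivemind internals; it assumes main_parameters are stored flat, in the same order as the inner optimizer's groups.

    import torch

    class WrappedOptimizer:
        def __init__(self, main_parameters, inner_optimizer: torch.optim.Optimizer):
            # main_parameters: user-facing tensors, flat, in the same order as the
            # parameters held inside inner_optimizer's param groups
            self.main_parameters = tuple(main_parameters)
            self.inner_optimizer = inner_optimizer

        @property
        def param_groups(self):
            # Copy every group dict and swap the inner optimizer's parameter list
            # for the matching slice of the user-facing parameters.
            next_index = 0
            groups = tuple(dict(group) for group in self.inner_optimizer.param_groups)
            for group in groups:
                num_params = len(group["params"])
                group["params"] = self.main_parameters[next_index: next_index + num_params]
                next_index += num_params
            assert next_index == len(self.main_parameters), "groups must cover all main parameters"
            return groups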
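
Note on the __del__ change: it resolves the old "TODO check os.getpid!!!" by recording the PID of the process that created the optimizer and only shutting down from that same process, so a forked child exiting does not tear down the parent's background machinery. A minimal sketch of the guard, using a hypothetical ForkSafeResource class rather than the real Optimizer:

    import os

    class ForkSafeResource:
        def __init__(self):
            # Remember which process created this object.
            self._parent_pid = os.getpid()
            self._alive = True

        def is_alive(self) -> bool:
            return self._alive

        def shutdown(self):
            # Release background threads / worker processes here.
            self._alive = False

        def __del__(self):
            # After os.fork(), the child inherits a copy of this object and runs
            # __del__ on exit; the PID check keeps the child from shutting down
            # resources that the parent process still uses.
            if self._parent_pid == os.getpid() and self.is_alive():
                self.shutdown()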