Artem Chumachenko committed 3 years ago
Commit 847506a9d0

+ 2 - 2
hivemind/optim/experimental/power_ef_averager.py

@@ -62,11 +62,11 @@ class PowerEFGradientAverager(GradientAverager):
         self.parameters = tuple(parameters)
         self._uncompressed_gradients = set(i for i, grad in enumerate(self._grads_from_parameters()) if len(tuple(grad.size())) == 1)
         self._gs = list(
-            torch.zeros_like(grad, device=accumulate_grads_on)
+            torch.zeros_like(grad, device="cpu")
             for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
         )
         self._qs = list(
-            torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device=accumulate_grads_on)
+            torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device="cpu")
             for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
         )
         for tensor in (self._qs + self._gs):
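
In this first hunk, the `_gs` and `_qs` buffers of PowerEFGradientAverager are now allocated on the CPU instead of on `accumulate_grads_on`. A plausible reason, sketched below under that assumption, is that these buffers are immediately passed through `share_memory_()`, which places CPU storage in shared memory so another process (e.g. a background averager) can access it. The gradient shape and `rank` in the sketch are hypothetical stand-ins, not values from the commit.

import torch

# Minimal sketch, assuming the buffers must live in CPU shared memory
# so a separate process can read them. Shapes and rank are hypothetical.
grad = torch.zeros(4, 8)   # stand-in for one model gradient
rank = 1                   # stand-in for the low-rank compression rank

g = torch.zeros_like(grad, device="cpu")
q = torch.rand((grad.reshape((grad.size(0), -1)).size(1), rank), device="cpu")

for tensor in (g, q):
    assert tensor.grad_fn is None, "must be a leaf tensor"
    tensor.share_memory_()  # moves CPU storage into shared memory

print(g.is_shared(), q.is_shared())  # True True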

+ 2 - 10
hivemind/optim/experimental/power_sgd_averager.py

@@ -65,18 +65,10 @@ class PowerSGDGradientAverager(PowerEFGradientAverager):
         self._local_ef = averager_local_ef
         self._uncompressed_gradients = set(i for i, grad in enumerate(self._grads_from_parameters()) if len(tuple(grad.size())) == 1)
         self._ms = list(
-            torch.zeros_like(grad, device=accumulate_grads_on)
+            torch.zeros_like(grad, device="cpu")
             for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
         )
-        self._gs = list(
-            torch.zeros_like(grad, device=accumulate_grads_on)
-            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
-        )
-        self._qs = list(
-            torch.rand((grad.reshape((grad.size(0), -1)).size(1), self.rank), device=accumulate_grads_on)
-            for idx, grad in enumerate(self._grads_from_parameters()) if idx not in self._uncompressed_gradients
-        )
-        for tensor in (self._qs + self._gs):
+        for tensor in (self._ms):
             if tensor is not None:
                 assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
                 tensor.share_memory_()
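
In this second hunk, `PowerSGDGradientAverager` stops re-creating `_gs` and `_qs` and keeps only its own `_ms` buffers, presumably because its parent `PowerEFGradientAverager.__init__` (first hunk) already builds and shares those. Below is a minimal sketch of that division of labor, with simplified hypothetical constructors that only mirror the buffer setup:

import torch

# Sketch of the assumed inheritance: the parent owns _gs/_qs, the child
# only adds _ms. Class and buffer names mirror the diff; constructors
# are simplified stand-ins, not the real hivemind signatures.
class ParentAverager:
    def __init__(self, grads, rank):
        self.rank = rank
        self._gs = [torch.zeros_like(g, device="cpu") for g in grads]
        self._qs = [
            torch.rand((g.reshape((g.size(0), -1)).size(1), rank), device="cpu")
            for g in grads
        ]
        for t in self._gs + self._qs:
            t.share_memory_()

class ChildAverager(ParentAverager):
    def __init__(self, grads, rank):
        super().__init__(grads, rank)  # parent already built _gs/_qs
        self._ms = [torch.zeros_like(g, device="cpu") for g in grads]
        for t in self._ms:
            t.share_memory_()

c = ChildAverager([torch.zeros(4, 8)], rank=1)
assert all(t.is_shared() for t in c._ms + c._gs + c._qs)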