
Merge remote-tracking branch 'origin/power_ef_new' into power_ef_new

Artem Chumachenko, 3 years ago
Commit c33f843e0f
1 file changed, 4 additions and 4 deletions

hivemind/optim/experimental/power_ef_averager.py (+4, -4)

@@ -122,7 +122,7 @@ class PowerEFGradientAverager(GradientAverager):
             )
 
             async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                compressed_tensors = [lt for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients]
+                compressed_tensors = [lt.to("cpu") for idx, lt in enumerate(local_tensors) if idx not in self._uncompressed_gradients]
                 cs = [torch.zeros_like(grad, device="cpu") for grad in compressed_tensors]
                 for c, g, cg in zip(cs, self._gs, compressed_tensors):
                     torch.sub(cg, g, out=c)
@@ -191,7 +191,7 @@ class PowerEFGradientAverager(GradientAverager):
                     c.copy_(new_c.reshape(c.size()))
 
                 for c, g in zip(cs, self._gs):
-                    torch.add(g, c * 0.9, out=g)
+                    torch.add(g, c, out=g)
 
                 return allreduce1.gathered
         except BaseException as e:
@@ -214,11 +214,11 @@ class PowerEFGradientAverager(GradientAverager):
                 assert len(averaged_grads) == len(self.parameters)
                 old_grads = [param.grad for param in self.parameters]
                 for param, new_grad in zip(self.parameters, averaged_grads):
-                    param.grad = new_grad
+                    param.grad.copy_(new_grad)
                 yield
             finally:
                 for param, old_grad in zip(self.parameters, old_grads):
-                    param.grad = old_grad
+                    param.grad.copy_(old_grad)
             for cg, oag in zip(compressed_tensors, old_averaged):
                 cg.copy_(oag)
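
A minimal sketch (not part of the commit) of the rationale behind the last hunk: rebinding param.grad replaces the tensor object, so any code that kept a reference to the original gradient tensor would keep seeing stale values, whereas copy_ writes into the existing storage. This is presumably why the commit switches to copy_ when temporarily swapping in the averaged gradients; the names below are illustrative, not taken from hivemind. (The second hunk simply drops the 0.9 damping factor so the error-feedback accumulator self._gs receives the full compensation c, and the first hunk moves the compressed local tensors to CPU before computing it.)

    import torch

    param = torch.nn.Parameter(torch.zeros(3))
    param.grad = torch.ones(3)           # original accumulated gradient
    grad_ref = param.grad                # a reference held elsewhere (e.g. optimizer state)
    averaged = torch.full((3,), 2.0)     # stand-in for the averaged gradient

    # Rebinding: grad_ref still points at the old tensor of ones.
    param.grad = averaged
    assert not torch.equal(grad_ref, param.grad)

    # In-place copy (what the commit switches to): the same storage is updated,
    # so every existing reference observes the new values.
    param.grad = grad_ref                # restore the original tensor object
    param.grad.copy_(averaged)
    assert torch.equal(grad_ref, param.grad)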