浏览代码

Remove AMP, update lr

Max Ryabinin 4 年之前
父节点
当前提交
0ff0c689e8
共有 2 个文件被更改,包括 1 次插入3 次删除
  1. 1 1
      hivemind/moe/server/__init__.py
  2. 0 2
      hivemind/moe/server/expert_backend.py

+ 1 - 1
hivemind/moe/server/__init__.py

@@ -255,7 +255,7 @@ class Server(threading.Thread):
             optim = OffloadOptimizer(
             optim = OffloadOptimizer(
                 optimizer_grouped_parameters,
                 optimizer_grouped_parameters,
                 optim_cls=LambWithGradientClipping,
                 optim_cls=LambWithGradientClipping,
-                lr=0.00176,
+                lr=0.0035355339059327377,
                 betas=(0.9, 0.999),
                 betas=(0.9, 0.999),
                 eps=1e-6,
                 eps=1e-6,
                 weight_decay=0.01,
                 weight_decay=0.01,

+ 0 - 2
hivemind/moe/server/expert_backend.py

@@ -96,7 +96,6 @@ class ExpertBackend:
         self.update_count = 0
         self.update_count = 0
         self.examples_processed = 0
         self.examples_processed = 0
 
 
-    @torch.cuda.amp.autocast()
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
@@ -122,7 +121,6 @@ class ExpertBackend:
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))
         return tuple(nested_flatten(outputs))
 
 
-    @torch.cuda.amp.autocast()
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         """
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually