Browse Source

Remove AMP, update lr

Max Ryabinin 3 years ago
parent
commit
0ff0c689e8
2 changed files with 1 additions and 3 deletions
  1. 1 1
      hivemind/moe/server/__init__.py
  2. 0 2
      hivemind/moe/server/expert_backend.py

+ 1 - 1
hivemind/moe/server/__init__.py

@@ -255,7 +255,7 @@ class Server(threading.Thread):
             optim = OffloadOptimizer(
                 optimizer_grouped_parameters,
                 optim_cls=LambWithGradientClipping,
-                lr=0.00176,
+                lr=0.0035355339059327377,
                 betas=(0.9, 0.999),
                 eps=1e-6,
                 weight_decay=0.01,

+ 0 - 2
hivemind/moe/server/expert_backend.py

@@ -96,7 +96,6 @@ class ExpertBackend:
         self.update_count = 0
         self.examples_processed = 0
 
-    @torch.cuda.amp.autocast()
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
@@ -122,7 +121,6 @@ class ExpertBackend:
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))
 
-    @torch.cuda.amp.autocast()
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually