
Remove AMP, update lr

Max Ryabinin 3 years ago
parent
commit
0ff0c689e8

+ 1 - 1
hivemind/moe/server/__init__.py

@@ -255,7 +255,7 @@ class Server(threading.Thread):
             optim = OffloadOptimizer(
                 optimizer_grouped_parameters,
                 optim_cls=LambWithGradientClipping,
-                lr=0.00176,
+                lr=0.0035355339059327377,
                 betas=(0.9, 0.999),
                 eps=1e-6,
                 weight_decay=0.01,
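
For reference (not stated in the commit message): the new literal matches 0.005 / sqrt(2), i.e. 0.0025 · sqrt(2) ≈ 0.0035355, roughly double the previous 0.00176. A minimal check; whether this reflects a square-root learning-rate scaling rule is an assumption, not something the commit asserts:

    import math

    new_lr = 0.0035355339059327377  # literal written into the diff
    old_lr = 0.00176                # value being replaced

    # The new value coincides with 0.005 / sqrt(2) (equivalently 0.0025 * sqrt(2));
    # the square-root-scaling interpretation is an assumption, not from the commit.
    print(math.isclose(new_lr, 0.005 / math.sqrt(2)))  # True
    print(new_lr / old_lr)                             # ~2.01, about twice the old rate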

+ 0 - 2
hivemind/moe/server/expert_backend.py

@@ -96,7 +96,6 @@ class ExpertBackend:
         self.update_count = 0
         self.examples_processed = 0
 
-    @torch.cuda.amp.autocast()
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
@@ -122,7 +121,6 @@ class ExpertBackend:
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))
 
-    @torch.cuda.amp.autocast()
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually
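
Removing the @torch.cuda.amp.autocast() decorators means forward() and backward() now run in the default precision rather than under automatic mixed precision. If mixed precision were wanted again at a particular call site, the same PyTorch machinery works as a context manager; a minimal sketch (the backend and input names are placeholders, not part of this change):

    import torch

    def run_forward_with_amp(backend, *inputs):
        # Wrap only this call in autocast instead of decorating ExpertBackend.forward;
        # outside the `with` block, computation stays in full precision.
        with torch.cuda.amp.autocast():
            return backend.forward(*inputs)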