Procházet zdrojové kódy

Support bfloat16 for autograd (#499)

Dmitry Baranchuk před 3 roky
rodič
revize
28261470e4
1 změnil soubory, kde provedl 1 přidání a 3 odebrání
  1. 1 3
      hivemind/moe/server/module_backend.py

+ 1 - 3
hivemind/moe/server/module_backend.py

@@ -118,9 +118,7 @@ class ModuleBackend:
 
         with torch.enable_grad():
             args = [
-                tensor.detach().requires_grad_(True)
-                if tensor.dtype in (torch.half, torch.float, torch.double)
-                else tensor.detach()
+                tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach()
                 for tensor in args
             ]
             kwargs = {