Explorar el Código

Support bfloat16 for autograd (#499)

Dmitry Baranchuk hace 3 años
padre
commit
28261470e4
Se han modificado 1 ficheros con 1 adiciones y 3 borrados
  1. 1 3
      hivemind/moe/server/module_backend.py

+ 1 - 3
hivemind/moe/server/module_backend.py

@@ -118,9 +118,7 @@ class ModuleBackend:
 
         with torch.enable_grad():
             args = [
-                tensor.detach().requires_grad_(True)
-                if tensor.dtype in (torch.half, torch.float, torch.double)
-                else tensor.detach()
+                tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach()
                 for tensor in args
             ]
             kwargs = {