Quellcode durchsuchen

Support bfloat16 for autograd (#499)

Dmitry Baranchuk vor 3 Jahren
Ursprung
Commit
28261470e4
1 geänderte Dateien mit 1 neuen und 3 gelöschten Zeilen
  1. 1 3
      hivemind/moe/server/module_backend.py

+ 1 - 3
hivemind/moe/server/module_backend.py

@@ -118,9 +118,7 @@ class ModuleBackend:
 
 
         with torch.enable_grad():
         with torch.enable_grad():
             args = [
             args = [
-                tensor.detach().requires_grad_(True)
-                if tensor.dtype in (torch.half, torch.float, torch.double)
-                else tensor.detach()
+                tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach()
                 for tensor in args
                 for tensor in args
             ]
             ]
             kwargs = {
             kwargs = {