|
@@ -118,9 +118,7 @@ class ModuleBackend:
|
|
|
|
|
|
with torch.enable_grad():
|
|
|
args = [
|
|
|
- tensor.detach().requires_grad_(True)
|
|
|
- if tensor.dtype in (torch.half, torch.float, torch.double)
|
|
|
- else tensor.detach()
|
|
|
+ tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach()
|
|
|
for tensor in args
|
|
|
]
|
|
|
kwargs = {
|