@@ -1,4 +1,5 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+from contextlib import nullcontext
 
 import torch
 from torch import nn
@@ -48,6 +49,7 @@ class ExpertBackend:
         optimizer: torch.optim.Optimizer,
         *,
         device: torch.device,
+        fp16: bool = False,
         scheduler: Callable = None,
         args_schema: Tuple[BatchTensorDescriptor, ...] = None,
         kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
@@ -61,6 +63,7 @@ class ExpertBackend:
         self.expert = expert.to(device)
         self.optimizer, self.name = optimizer, name
         self.device = device
+        self.fp16 = fp16
 
         if scheduler is None:
             self.scheduler = None
@@ -115,7 +118,7 @@ class ExpertBackend:
         if args[0].shape[0] == 0:
             raise RuntimeError("Batch should contain more than 0 samples")
 
-        with torch.no_grad():
+        with torch.no_grad(), torch.cuda.amp.autocast() if self.fp16 else nullcontext():
             outputs = self.expert(*args, **kwargs)
 
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
@@ -140,7 +143,7 @@ class ExpertBackend:
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
-        with torch.enable_grad():
+        with torch.enable_grad(), torch.cuda.amp.autocast() if self.fp16 else nullcontext():
             args = [
                 tensor.detach().requires_grad_(True)
                 if tensor.dtype in (torch.half, torch.float, torch.double)
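
Taken together, the patch threads a single fp16 flag through ExpertBackend: when it is set, both the forward and backward passes run under torch.cuda.amp.autocast(), and otherwise nullcontext() preserves the previous behaviour. A minimal usage sketch follows, assuming ExpertBackend is imported from the patched module and using only the constructor keywords visible in the diff; the expert module, optimizer, sizes, and name below are placeholders, not part of the patch:

    import torch
    from torch import nn

    # Placeholder expert and optimizer; any nn.Module and torch optimizer would do.
    expert = nn.Linear(64, 64)
    backend = ExpertBackend(
        name="demo_expert",                                    # placeholder name
        expert=expert,
        optimizer=torch.optim.SGD(expert.parameters(), lr=0.01),
        device=torch.device("cuda"),
        fp16=True,  # new flag: forward/backward run inside torch.cuda.amp.autocast()
    )
    # With fp16=False (the default), nullcontext() is used instead and behaviour is unchanged.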