|
@@ -8,32 +8,32 @@ from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorProt
|
|
|
|
|
|
|
|
|
class ExpertBackend(nn.Module):
|
|
|
+ """
|
|
|
+ ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with TesseractRuntime
|
|
|
+ By default, ExpertBackend handles three types of requests:
|
|
|
+
|
|
|
+ - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
|
|
|
+ - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
|
|
|
+ - get_info - return expert metadata. Not batched.
|
|
|
+
|
|
|
+ :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
|
|
|
+
|
|
|
+ - Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of same type
|
|
|
+ - All \*args, \*\*kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
|
|
|
+ - We recommend using experts that are ~invariant to the order in which they process batches
|
|
|
+
|
|
|
+ :param opt: torch optimizer to be applied on every backward call
|
|
|
+ :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
|
|
|
+ :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
|
|
|
+ :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
|
|
|
+ :param kwargs: extra parameters to be forwarded into TaskPool.__init__
|
|
|
+ """
|
|
|
+
|
|
|
def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
|
|
|
args_schema: Tuple[BatchTensorProto, ...] = None,
|
|
|
kwargs_schema: Dict[str, BatchTensorProto] = None,
|
|
|
outputs_schema: Union[BatchTensorProto, Tuple[BatchTensorProto, ...]] = None,
|
|
|
**kwargs):
|
|
|
- """
|
|
|
- ExpertBackend implements how a given expert processes tasks.
|
|
|
- By default, there are two tasks:
|
|
|
- * forward receives inputs and produces outputs
|
|
|
- * backward receives gradients w.r.t. outputs, computes gradients w.r.t. inputs and trains the expert
|
|
|
-
|
|
|
- All incoming tasks are grouped by type (forward/backward) and sent into the corresponding pool,
|
|
|
- where tasks are grouped into minibatches and prepared for processing on device;
|
|
|
- The results are dispatched to task authors with SharedFuture.set_result.
|
|
|
-
|
|
|
- :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
|
|
|
- * Experts must always receive the same set of *args and **kwargs and produce output tensors of same type
|
|
|
- * All *args, **kwargs and outputs must be *tensors* where 0-th dimension represents to batch size
|
|
|
- * We recommend using experts that are ~invariant to the order in which they process batches
|
|
|
-
|
|
|
- :param opt: torch optimizer to be applied on every backward call
|
|
|
- :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
|
|
|
- :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
|
|
|
- :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
|
|
|
- :param kwargs: extra parameters to be forwarded into TaskPool.__init__
|
|
|
- """
|
|
|
super().__init__()
|
|
|
self.expert, self.opt, self.name = expert, opt, name
|
|
|
|
|
@@ -56,15 +56,43 @@ class ExpertBackend(nn.Module):
|
|
|
self.backward_pool = TaskPool(self.backward, uid=f'{self.name}_backward', **kwargs)
|
|
|
|
|
|
def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
|
|
+ """
|
|
|
+ Apply forward pass to an aggregated batch of requests. Used by TesseractRuntime, do not call this manually;
|
|
|
+ To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.
|
|
|
+
|
|
|
+ Subclassing:
|
|
|
+ This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;
|
|
|
+
|
|
|
+ It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;
|
|
|
+
|
|
|
+ .. todo state - we recommend stateless but you can save state if you want. disable batchnorm track running stats
|
|
|
+
|
|
|
+ """
|
|
|
args, kwargs = nested_pack(inputs, structure=self.forward_schema)
|
|
|
|
|
|
with torch.no_grad():
|
|
|
outputs = self.expert(*args, **kwargs)
|
|
|
|
|
|
- # Note: TaskPool requires function to accept and return a **list** of values, we pack/unpack it on client side
|
|
|
+ # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
|
|
|
return tuple(nested_flatten(outputs))
|
|
|
|
|
|
def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
|
|
+ """
|
|
|
+ Apply backward pass to an aggregated batch of requests. Used by TesseractRuntime, do not call this manually
|
|
|
+ To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.
|
|
|
+
|
|
|
+ Subclassing:
|
|
|
+ This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;
|
|
|
+
|
|
|
+ It should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``;
|
|
|
+
|
|
|
+ TesseractRuntime doesn't guarantee that backward will be performed in the same order and for the same data
|
|
|
+ as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.
|
|
|
+
|
|
|
+ .. todo state, randomness, etc
|
|
|
+
|
|
|
+ Please make sure to call ``ExpertBackend.apply_gradients`` **within** this method, otherwise the expert will not train
|
|
|
+ """
|
|
|
(args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
|
|
|
|
|
|
with torch.enable_grad():
|
|
@@ -87,12 +115,18 @@ class ExpertBackend(nn.Module):
|
|
|
for x in nested_flatten((args, kwargs)))
|
|
|
|
|
|
def apply_gradients(self) -> None:
|
|
|
+ """
|
|
|
+ Train the expert for a single step. This method is called by ``ExpertBackend.backward`` after computing gradients.
|
|
|
+ """
|
|
|
self.opt.step()
|
|
|
self.opt.zero_grad()
|
|
|
|
|
|
- def get_pools(self) -> Sequence[TaskPool]:
|
|
|
- return self.forward_pool, self.backward_pool
|
|
|
-
|
|
|
def get_info(self) -> Dict[str, Any]:
|
|
|
+ """ Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration. """
|
|
|
return dict(forward_schema=self.forward_schema, outputs_schema=self.outputs_schema,
|
|
|
keyword_names=tuple(self.kwargs_schema.keys()))
|
|
|
+
|
|
|
+ def get_pools(self) -> Sequence[TaskPool]:
|
|
|
+ """ return all pools that should be processed by ``TesseractRuntime`` """
|
|
|
+ return self.forward_pool, self.backward_pool
|
|
|
+
|