|
@@ -21,6 +21,9 @@ class ExpertBackend(nn.Module):
|
|
|
- Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of the same type
|
|
|
- All \*args, \*\*kwargs and outputs must be **tensors** where the 0-th dimension represents the batch size
|
|
|
- We recommend using experts that are ~invariant to the order in which they process batches
|
|
|
+ - Using randomness (e.g. Dropout) leads to different random samples in the forward and backward passes. If you want to ensure consistency,
|
|
|
+ you should explicitly register these random variables as model outputs, so that they are sent back to the client.
|
|
|
+ See hivemind.utils.custom_layers.DeterministicDropout for an example
|
|
|
|
|
|
:param opt: torch optimizer to be applied on every backward call
|
|
|
:param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
|
|
@@ -65,7 +68,8 @@ class ExpertBackend(nn.Module):
|
|
|
|
|
|
It should return gradients w.r.t. inputs that follow ``nested_flatten(self.outputs_schema)``;
|
|
|
|
|
|
- .. todo state - we recommend stateless but you can save state if you want. disable batchnorm track running stats
|
|
|
+ .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
|
|
|
+ .. For now, either register all buffers as outputs or avoid stateful experts
|
|
|
|
|
|
"""
|
|
|
args, kwargs = nested_pack(inputs, structure=self.forward_schema)
|
|
@@ -89,15 +93,17 @@ class ExpertBackend(nn.Module):
|
|
|
Runtime doesn't guarantee that backward will be performed in the same order and for the same data
|
|
|
as forward, so we recommend a stateless backward pass that re-runs the expert forward pass inside backward.
|
|
|
|
|
|
- .. todo state, randomness, etc
|
|
|
+ .. todo correct state handling (see forward)
|
|
|
|
|
|
Please make sure to call ``ExpertBackend.apply_gradients`` **within** this method, otherwise the expert will not train
|
|
|
"""
|
|
|
(args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
|
|
|
|
|
|
with torch.enable_grad():
|
|
|
- args = [tensor.detach().requires_grad_(True) for tensor in args]
|
|
|
- kwargs = {input_key: tensor.detach().requires_grad_(True) for input_key, tensor in kwargs.items()}
|
|
|
+ args = [tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
|
|
|
+ else tensor.detach() for tensor in args]
|
|
|
+ kwargs = {input_key: (tensor.detach().requires_grad_(True) if tensor.dtype in (torch.half, torch.float, torch.double)
|
|
|
+ else tensor.detach()) for input_key, tensor in kwargs.items()}
|
|
|
|
|
|
outputs = self.expert(*args, **kwargs)
|
|
|
assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
|
|
@@ -129,4 +135,3 @@ class ExpertBackend(nn.Module):
|
|
|
def get_pools(self) -> Sequence[TaskPool]:
|
|
|
""" return all pools that should be processed by ``Runtime`` """
|
|
|
return self.forward_pool, self.backward_pool
|
|
|
-
|