@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
@@ -8,19 +8,20 @@ from hivemind.utils.logging import get_logger
 from hivemind.utils.nested import nested_compare, nested_flatten, nested_map, nested_pack
 from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
 
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 logger = get_logger(__name__)
 
 
-class ExpertBackend:
+class ModuleBackend:
     """
-    ExpertBackend is a wrapper around torch module that allows it to run tasks asynchronously with Runtime
-    By default, ExpertBackend handles three types of requests:
+    ModuleBackend is a wrapper around a torch module that allows it to run tasks asynchronously with Runtime.
+    By default, ModuleBackend handles three types of requests:
 
     - forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
     - backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and **update expert**. Also batched.
     - get_info - return expert metadata. Not batched.
 
-    :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
+    :param module: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
 
     - Experts must always receive the same set of args and kwargs and produce output tensors of the same type
     - All args, kwargs and outputs must be **tensors** where the 0-th dimension represents the batch size
@@ -34,49 +35,37 @@ class ExpertBackend:
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
     :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
     :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
-    :param num_warmup_steps: the number of warmup steps for LR schedule
-    :param num_total_steps: the total number of steps for LR schedule
-    :param clip_grad_norm: maximum gradient norm used for clipping
     :param kwargs: extra parameters to be forwarded into TaskPool.__init__
     """
 
     def __init__(
         self,
         name: str,
-        expert: nn.Module,
-        optimizer: torch.optim.Optimizer,
+        module: nn.Module,
         *,
-        scheduler: Callable = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        scheduler: Optional[LRSchedulerBase] = None,
         args_schema: Tuple[BatchTensorDescriptor, ...] = None,
         kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
         outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
-        num_warmup_steps: int = None,
-        num_total_steps: int = None,
-        clip_grad_norm: float = None,
         **kwargs,
     ):
         super().__init__()
-        self.expert, self.optimizer, self.name = expert, optimizer, name
-
-        if scheduler is None:
-            self.scheduler = None
-        else:
-            assert optimizer is not None and num_warmup_steps is not None and num_total_steps is not None
-            self.scheduler = scheduler(self.optimizer, num_warmup_steps, num_total_steps)
-        self.clip_grad_norm = clip_grad_norm
+        self.name, self.module, self.optimizer, self.scheduler = name, module, optimizer, scheduler
 
         self.args_schema = args_schema = tuple(args_schema or ())
         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
         assert args_schema or kwargs_schema, (
-            "expert must receive at least one positional or keyword input."
+            "Module must take at least one positional or keyword input."
             " Did you forget to provide args_schema/kwargs_schema?"
         )
+        assert optimizer is not None or scheduler is None, "scheduler should only be used if optimizer is not None"
 
         if outputs_schema is None:
             # run expert once to get outputs schema
             dummy_args = tuple(sample.make_zeros(DUMMY_BATCH_SIZE) for sample in args_schema)
             dummy_kwargs = {key: sample.make_zeros(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
-            dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
+            dummy_outputs = self.module(*dummy_args, **dummy_kwargs)
             outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 
         self.forward_schema = (self.args_schema, self.kwargs_schema)  # inputs for forward
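
For review convenience, here is a minimal construction sketch under the new signature. It is hedged and illustrative: the FFN module, the expert name, the hidden size of 512, the max_batch_size value, and the ModuleBackend import path are assumptions rather than part of this diff; BatchTensorDescriptor comes from the same module imported above.

    # Hedged sketch: constructing a ModuleBackend with the now keyword-only, optional optimizer/scheduler.
    import torch
    from torch import nn
    from hivemind.moe.server.module_backend import ModuleBackend  # assumed import path after the rename
    from hivemind.utils.tensor_descr import BatchTensorDescriptor

    ffn = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))  # illustrative module
    opt = torch.optim.Adam(ffn.parameters(), lr=1e-3)

    backend = ModuleBackend(
        name="expert.0",
        module=ffn,
        optimizer=opt,  # optional since this change
        scheduler=torch.optim.lr_scheduler.StepLR(opt, step_size=1000),  # a constructed LR scheduler instance
        args_schema=(BatchTensorDescriptor(512),),
        outputs_schema=BatchTensorDescriptor(512),
        max_batch_size=4096,  # extra kwargs are forwarded to TaskPool.__init__ (per the docstring)
    )
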
@@ -87,22 +76,17 @@ class ExpertBackend:
         self.forward_pool = TaskPool(self.forward, name=f"{self.name}_forward", **kwargs)
         self.backward_pool = TaskPool(self.backward, name=f"{self.name}_backward", **kwargs)
 
-        self.update_count = 0
-        self.examples_processed = 0
-
     def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply forward pass to an aggregated batch of requests. Used by Runtime, do not call this manually;
-        To submit a request for asynchronous processing, please use ``ExpertBackend.forward_pool.submit_task``.
+        To submit a request for asynchronous processing, please use ``ModuleBackend.forward_pool.submit_task``.
+
+        .. warning:: if the underlying module performs non-gradient updates (e.g. batchnorm), it will be updated twice:
+           once during the forward pass, and again during backward. This behavior is similar to gradient checkpointing.
 
         Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.forward_schema)``;
-
           It should return outputs that follow ``nested_flatten(self.outputs_schema)``;
-
-           .. todo we handle layer states (e.g. batchnorm stats) incorrectly, updating them twice.
-           .. For now, either register all buffers as outputs or avoid stateful experts
-
         """
         args, kwargs = nested_pack(inputs, structure=self.forward_schema)
 
@@ -110,7 +94,7 @@ class ExpertBackend:
             raise RuntimeError("Batch should contain more than 0 samples")
 
         with torch.no_grad():
-            outputs = self.expert(*args, **kwargs)
+            outputs = self.module(*args, **kwargs)
 
         # Note: TaskPool requires function to accept and return a flat tuple of values, we pack/unpack it on client side
         return tuple(nested_flatten(outputs))
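
To make the flat-tuple contract in the note above concrete, here is a hedged client-side sketch that continues the construction example earlier; it assumes a running Runtime is consuming the pools and that submit_task returns a future-like object.

    # Hedged sketch: submitting one aggregated batch through the forward pool.
    from hivemind.utils.nested import nested_flatten, nested_pack

    batch = torch.randn(32, 512)  # 0-th dimension is the batch size, matching args_schema above
    flat_inputs = tuple(nested_flatten(((batch,), {})))  # same layout as backend.forward_schema
    future = backend.forward_pool.submit_task(*flat_inputs)  # assumed to return a future-like object
    flat_outputs = future.result()  # flat tuple of output tensors, as returned by forward above
    outputs = nested_pack(flat_outputs, structure=backend.outputs_schema)  # assumes outputs_schema is stored
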
@@ -118,7 +102,7 @@ class ExpertBackend:
     def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         """
         Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually
-        To submit a request for asynchronous processing, please use ``ExpertBackend.backward_pool.submit_task``.
+        To submit a request for asynchronous processing, please use ``ModuleBackend.backward_pool.submit_task``.
 
         Subclassing:
            This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;
@@ -128,9 +112,7 @@ class ExpertBackend:
            Runtime doesn't guarantee that backward will be performed in the same order and for the same data
           as forward, so we recommend a stateless backward pass that re-runs the expert's forward pass inside backward.
 
-           .. todo correct state handling (see forward)
-
-           Please make sure to call ``ExpertBackend.apply_gradients`` here, otherwise the expert will not train
+           Please make sure to call ``ModuleBackend.on_backward`` after each call to backward
         """
         (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
 
@@ -148,7 +130,7 @@ class ExpertBackend:
 
             batch_size = args[0].size(0)
 
-            outputs = self.expert(*args, **kwargs)
+            outputs = self.module(*args, **kwargs)
             assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
 
             outputs_flat = tuple(nested_flatten(outputs))
@@ -163,65 +145,45 @@ class ExpertBackend:
             torch.autograd.backward(
                 outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
             )
-            self.apply_gradients(batch_size)
+            self.on_backward(batch_size)
 
         return tuple(
             x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs))
         )
 
-    def apply_gradients(self, batch_size) -> None:
+    def on_backward(self, batch_size: int) -> None:
         """
-        Train the expert for one step. This method is called by ``ExpertBackend.backward`` after computing gradients.
+        Train the expert for one step. This method is called by ``ModuleBackend.backward`` after computing gradients.
         """
-        if self.clip_grad_norm is not None:
-            torch.nn.utils.clip_grad_norm_(self.expert.parameters(), self.clip_grad_norm)
-
-        self.optimizer.step()
-        self.optimizer.zero_grad()
+        if self.optimizer is not None:
+            self.optimizer.step()
+            self.optimizer.zero_grad()
 
         if self.scheduler is not None:
             self.scheduler.step()
 
-        self.update_count += 1
-        self.examples_processed += batch_size
-
-    def get_stats(self) -> Dict:
-        """
-        Return current expert training statistics (number of updates, number of processed examples after
-        last optimizer step)
-        """
-        return {"updates": self.update_count, "examples_processed": self.examples_processed}
-
-    def get_full_state(self) -> Dict:
-        """
-        Return the current state of the expert (including batch processing statistics)
-        """
-        full_state = {
-            "stats": self.get_stats(),
-            "model": self.expert.state_dict(),
-            "optimizer": self.optimizer.state_dict(),
-            "scheduler": {} if self.scheduler is None else self.scheduler.state_dict(),
-        }
+    def state_dict(self) -> Dict:
+        """Return the current state of the module, optimizer, and scheduler"""
+        full_state = dict(module=self.module.state_dict())
+        if self.optimizer is not None:
+            full_state["optimizer"] = self.optimizer.state_dict()
+        if self.scheduler is not None:
+            full_state["scheduler"] = self.scheduler.state_dict()
         return full_state
 
-    def load_full_state(self, state_dict: Dict):
-        if "stats" in state_dict:
-            self.update_count = state_dict["stats"]["updates"]
-            self.examples_processed = state_dict["stats"]["examples_processed"]
-        else:
-            logger.warning(f"Batch processing stats missing for expert {self.name}")
-
-        self.expert.load_state_dict(state_dict["model"])
+    def load_state_dict(self, state_dict: Dict):
+        self.module.load_state_dict(state_dict["module"])
+        if self.optimizer is not None:
+            if "optimizer" in state_dict:
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                logger.warning(f"Optimizer state missing for {self.name}")
 
-        if "optimizer" in state_dict:
-            self.optimizer.load_state_dict(state_dict["optimizer"])
-        else:
-            logger.warning(f"Optimizer state missing for expert {self.name}")
-
-        if self.scheduler is not None and "scheduler" in state_dict:
-            self.scheduler.load_state_dict(state_dict["scheduler"])
-        else:
-            logger.warning(f"Learning rate scheduler state missing for expert {self.name}")
+        if self.scheduler is not None:
+            if "scheduler" in state_dict:
+                self.scheduler.load_state_dict(state_dict["scheduler"])
+            else:
+                logger.warning(f"Learning rate scheduler state missing for {self.name}")
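
Since gradient clipping and the update/example counters were dropped from the base class, a server that still needs them can override the new on_backward hook; below is a hedged sketch (the ClippingBackend name and the default threshold are illustrative, not part of this diff), followed by a checkpoint round-trip through the renamed state_dict/load_state_dict pair, reusing the backend built earlier.

    # Illustrative subclass restoring gradient clipping and update counting on top of on_backward.
    class ClippingBackend(ModuleBackend):
        def __init__(self, *args, clip_grad_norm: float = 1.0, **kwargs):
            super().__init__(*args, **kwargs)
            self.clip_grad_norm = clip_grad_norm
            self.update_count = 0

        def on_backward(self, batch_size: int) -> None:
            torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.clip_grad_norm)
            super().on_backward(batch_size)  # optimizer.step(), zero_grad(), optional scheduler.step()
            self.update_count += 1

    # Checkpoint round-trip: only the states that actually exist are saved and restored.
    torch.save(backend.state_dict(), "backend.ckpt")
    backend.load_state_dict(torch.load("backend.ckpt"))
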
 
     def get_info(self) -> Dict[str, Any]:
         """Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""