@@ -53,13 +53,22 @@ class TransformerBackend(ModuleBackend):
         max_batch_size = self.forward_pool.max_batch_size
         device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
+            lambda args, kwargs: self.inference_step(*args, **kwargs),
+            max_batch_size=max_batch_size,
+            device=device,
+            name=f"{self.name}_inference",
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
+            lambda args, kwargs: self.forward(*args, **kwargs),
+            max_batch_size=max_batch_size,
+            device=device,
+            name=f"{self.name}_forward",
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
+            lambda args, kwargs: self.backward(*args, **kwargs),
+            max_batch_size=max_batch_size,
+            device=device,
+            name=f"{self.name}_backward",
         )

         self.dtype = backend_dtype
@@ -96,27 +105,25 @@ class TransformerBackend(ModuleBackend):
             cache_tensors.extend((keys, values))
         return cache_tensors

-    def forward(self, *args: torch.Tensor, active_adapter: Optional[str], **kwargs) -> Tuple[torch.Tensor, ...]:
+    def forward(self, active_adapter: Optional[str], *args: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]:
         with self._peft_module.using_adapter(active_adapter), torch.no_grad():
             return self.module(*args, **kwargs)

     def backward(
-        self, grad_outputs: torch.Tensor, *args, active_adapter: Optional[str], **kwargs
+        self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
     ) -> Tuple[torch.Tensor, ...]:
         assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
         with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
             (outputs,) = self.module(*args, **kwargs)
             assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape
             torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False)
-            return nested_map(lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None)
+            return nested_map(
+                lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs)
+            )

     @torch.inference_mode()
     def inference_step(
-        self,
-        hidden_states: torch.Tensor,
-        hypo_ids: torch.LongTensor,
-        kwargs: Dict[str, torch.Tensor],
-        inference_info: InferenceMetadata,
+        self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, inference_info: InferenceMetadata, **kwargs
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
         seq_len = hidden_states.shape[1]
@@ -217,8 +224,9 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend])
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
     assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values())
     first_pool = next(iter(backends.values())).inference_pool
+    merged_inference_func = _MergedInferenceStep(backends)
     merged_pool = PrioritizedTaskPool(
-        _MergedInferenceStep(backends),
+        lambda args, kwargs: merged_inference_func(*args, **kwargs),
         max_batch_size=first_pool.max_batch_size,
         device=first_pool.device,
         name=f"merged_inference",
@@ -237,9 +245,9 @@ class _MergedInferenceStep:
         self,
         hidden_states: torch.Tensor,
         hypo_ids: torch.LongTensor,
-        backend_kwargs: Sequence[Dict[str, torch.Tensor]],
         inference_infos: Sequence[InferenceMetadata],
         *optional_prompts: Optional[torch.Tensor],
+        backend_kwargs: Sequence[Dict[str, torch.Tensor]],
     ) -> Tuple[torch.Tensor, ...]:
         assert (
             len(inference_infos) == len(optional_prompts) == len(backend_kwargs)
@@ -248,6 +256,6 @@ class _MergedInferenceStep:
             if optional_prompt is not None:
                 hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
             (hidden_states,) = self.backends[inference_info.uid].inference_step(
-                hidden_states, hypo_ids, kwargs, inference_info
+                hidden_states, hypo_ids, inference_info, **kwargs
            )
         return (hidden_states,)
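
Context for reviewers (not part of the patch): each pool now receives `lambda args, kwargs: fn(*args, **kwargs)` instead of the bound method itself, which assumes that `PrioritizedTaskPool` hands its worker one packed `(args, kwargs)` pair rather than already-unpacked arguments. The sketch below is a minimal illustration of that assumed calling convention using a hypothetical `ToyTaskPool` stand-in; names such as `ToyTaskPool` and `submit_task` are illustrative and are not the real pool API.

from typing import Any, Callable, Dict, Tuple

class ToyTaskPool:
    """Hypothetical stand-in for PrioritizedTaskPool (illustration only)."""

    def __init__(self, worker: Callable[[Tuple[Any, ...], Dict[str, Any]], Any], name: str):
        # The worker takes one packed (args, kwargs) pair, hence the lambda adapters in the patch.
        self.worker, self.name = worker, name

    def submit_task(self, *args: Any, **kwargs: Any) -> Any:
        # A real pool would batch and prioritize tasks; here we just pack the call and invoke the worker.
        return self.worker(args, kwargs)

def inference_step(hidden_states, hypo_ids, inference_info, **kwargs):
    # Mirrors the new signature style: per-block extras arrive as **kwargs.
    return (hidden_states,)

pool = ToyTaskPool(lambda args, kwargs: inference_step(*args, **kwargs), name="toy_inference")
(outputs,) = pool.submit_task([[0.0]], [0], {"uid": "block0"}, attention_mask=None)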