@@ -35,7 +35,7 @@ async def run_rpc_forward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
-    args_structure: Any = None,
+    structure: Any,
 ) -> torch.Tensor:
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
@@ -45,22 +45,19 @@ async def run_rpc_forward(
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    if args_structure is not None:
-        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
-        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
-    hidden_states, prompts, *_ = flat_tensors
-
+    (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     assert hidden_states.ndim == 3
+    num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
     if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
     else:
         prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
 
     # Run a chain of requested backends
-    for backend, prompt in zip(requested_backends, prompts):
+    for backend, prompt, kwargs in zip(requested_backends, prompts, backend_kwargs):
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
 
@@ -70,8 +67,10 @@ async def run_rpc_forward(
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
             hidden_states,
-            active_adapter,
+            active_adapter=active_adapter,
+            **kwargs,
             priority=priority,
+            size=num_tokens,
         )
         assert isinstance(hidden_states, torch.Tensor)
         assert (
@@ -87,15 +86,13 @@ async def run_rpc_backward(
     active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
-    args_structure: Any = None,
+    structure: Any,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    if args_structure is not None:
-        # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
-        flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
-    inputs, grad_outputs, prompts, *_ = flat_tensors
-
+    (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
     # Cast inputs & grad outputs to backend dtype
-    inputs = inputs.to(requested_backends[0].dtype)
+    assert hidden_states.ndim == 3
+    num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
+    hidden_states = hidden_states.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
 
     if prompts is None or is_dummy(prompts):
@@ -106,32 +103,36 @@ async def run_rpc_backward(
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
     inter_inputs = []
-    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
-        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
+    for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], backend_kwargs):
+        assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs[:, : prompt.shape[1]] += prompt
-        inter_inputs.append(inputs)
+            hidden_states[:, : prompt.shape[1]] += prompt
+        inter_inputs.append(hidden_states)
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
         priority = prioritizer.prioritize(
-            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+            hidden_states, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
+        )
+        (hidden_states,) = await backend.forward_pool.submit_task(
+            hidden_states, active_adapter, **kwargs, priority=priority, size=num_tokens
         )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
 
-        assert isinstance(inputs, torch.Tensor)
+        assert isinstance(hidden_states, torch.Tensor)
 
     if not is_dummy(prompts[-1]):
-        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
-    inter_inputs.append(inputs)
+        hidden_states[:, : prompts[-1].shape[1]] += prompts[-1]
+    inter_inputs.append(hidden_states)
 
     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
     grad_prompts_reversed = []
     # Run a chain of requested backends
-    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
+    for inp, prompt, backend, kwargs in reversed(list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))):
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
+        (grad_outputs,) = await backend.backward_pool.submit_task(
+            inp, grad_outputs, active_adapter, **kwargs, priority=priority, size=num_tokens
+        )
 
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
@@ -152,7 +153,7 @@ async def iterate_rpc_inference(
     prioritizer: TaskPrioritizerBase,
     points: int,
     quant_type: QuantType,
-    args_structure: Any = None,
+    structure: Any = None,
 ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
     assert len(cache_handles) == len(requested_backends)
 
@@ -161,12 +162,9 @@ async def iterate_rpc_inference(
 
     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        if args_structure is not None:
-            # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
-            flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
-
-        hidden_states, prompts, hypo_ids, *_ = flat_tensors
+        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure)
         batch_size, length_increment, _ = hidden_states.shape
+        num_tokens = batch_size * length_increment
 
         # Cast inputs to backend dtype
         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -209,13 +207,21 @@ async def iterate_rpc_inference(
                     for uid, handles in zip(requested_uids, cache_handles)
                 )
                 (hidden_states,) = await requested_backends[0].inference_pool.submit_task(
-                    hidden_states, hypo_ids, inference_infos, *prompts, priority=priority
+                    hidden_states,
+                    hypo_ids,
+                    inference_infos,
+                    *prompts,
+                    backend_kwargs,
+                    priority=priority,
+                    size=num_tokens,
                 )
             else:
-                for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts):
+                for backend, uid, handles, prompt, kwargs in zip(
+                    requested_backends, requested_uids, cache_handles, prompts, backend_kwargs
+                ):
                     inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                     (hidden_states,) = await backend.inference_pool.submit_task(
-                        hidden_states, hypo_ids, inference_infos, prompt, priority=priority
+                        hidden_states, hypo_ids, inference_infos, prompt, **kwargs, priority=priority, size=num_tokens
                     )
 
         # serialize and send last layer outputs
@@ -228,3 +234,22 @@ async def iterate_rpc_inference(
 
         # prepare for next step
         prefix_length += length_increment
+
+
+def _check_inputs(
+    requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], structure: Any
+):
+    if structure is not None:
+        args, *backend_kwargs = unpack_args_kwargs(flat_tensors, structure)
+    else:
+        args, *backend_kwargs = flat_tensors, {}  # backward compatibility
+
+    if len(backend_kwargs) not in (1, len(requested_backends)):
+        raise RuntimeError(
+            f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts "
+            f"(one for each block). Found {len(backend_kwargs)} instead."
+        )
+    if len(backend_kwargs) == 1:
+        # repeat the single kwargs dict so that every requested block receives one dict
+        backend_kwargs = backend_kwargs * len(requested_backends)
+    assert len(backend_kwargs) == len(requested_backends)
+    return args, backend_kwargs
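
Usage note (not part of the patch): a minimal sketch of the backward-compatibility path of the new _check_inputs helper. It assumes the helper is importable from the server module touched by this diff (the import path below is an assumption) and that only len(requested_backends) is read from the backend list, so plain objects can stand in for TransformerBackend.

    import torch

    from petals.server.block_functions import _check_inputs  # hypothetical path; adjust to your checkout

    fake_backends = [object(), object(), object()]            # stand-ins for TransformerBackend
    flat_tensors = (torch.zeros(2, 8, 16), torch.empty(0))    # (hidden_states, empty DUMMY-style prompts)

    # Legacy clients send no structure: the single empty kwargs dict is broadcast to every block.
    args, backend_kwargs = _check_inputs(fake_backends, flat_tensors, structure=None)
    assert args is flat_tensors
    assert list(backend_kwargs) == [{}, {}, {}]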
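For context, here is a rough sketch (also not part of the patch) of how a client could produce the flat_tensors / structure pair that this server-side code consumes. It assumes pack_args_kwargs from petals.utils.packaging flattens nested tensors and returns (flat_tensors, structure), and that unpacking on the server yields the positional tensors followed by one kwargs dict per block; the per-block dicts shown here are placeholders.

    import torch
    from petals.utils.packaging import pack_args_kwargs  # assumed client-side counterpart of unpack_args_kwargs

    hidden_states = torch.randn(1, 8, 16)
    prompts = torch.empty(0)                 # empty tensor, same convention as the DUMMY / is_dummy() check
    per_block_kwargs = [{}, {}, {}]          # e.g. future peft-style overrides, one dict per requested block

    # Pack the positional tensors first, then one kwargs dict per block, so that the server-side
    # `args, *backend_kwargs = unpack_args_kwargs(flat_tensors, structure)` line recovers
    # (hidden_states, prompts) plus the per-block dicts.
    flat_tensors, structure = pack_args_kwargs((hidden_states, prompts), *per_block_kwargs)
    # flat_tensors would go into request.tensors; structure travels alongside as request metadata.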