
standardize: s/backend_kwargs/block_kwargs/g everywhere

Your Name · 1 year ago
commit 8eb1722f1e
2 changed files with 25 additions and 25 deletions
  1. src/petals/server/backend.py (+4 -4)
  2. src/petals/server/block_functions.py (+21 -21)

+ 4 - 4
src/petals/server/backend.py

@@ -242,12 +242,12 @@ class _MergedInferenceStep:
         hypo_ids: torch.LongTensor,
         inference_infos: Sequence[InferenceMetadata],
         *optional_prompts: Optional[torch.Tensor],
-        backend_kwargs: Sequence[Dict[str, torch.Tensor]],
+        block_kwargs: Sequence[Dict[str, torch.Tensor]],
     ) -> Tuple[torch.Tensor, ...]:
         assert (
-            len(inference_infos) == len(optional_prompts) == len(backend_kwargs)
-        ), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(backend_kwargs)} kwargs"
-        for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, backend_kwargs):
+            len(inference_infos) == len(optional_prompts) == len(block_kwargs)
+        ), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(block_kwargs)} kwargs"
+        for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, block_kwargs):
             if optional_prompt is not None:
                 hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
             (hidden_states,) = self.backends[inference_info.uid].inference_step(

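Note on the renamed parameter (not part of the commit): block_kwargs is a sequence with one dict of extra tensors per served block, paired positionally with that block's InferenceMetadata and optional prompt. A minimal, hypothetical sketch of what it might carry (the "attention_mask" key is illustrative only, not taken from the diff):

    # Illustrative sketch, assuming a two-block request; key names are hypothetical.
    import torch

    block_kwargs = (
        {"attention_mask": torch.ones(1, 8, dtype=torch.bool)},  # extra kwargs for the first block
        {"attention_mask": torch.ones(1, 8, dtype=torch.bool)},  # extra kwargs for the second block
    )
    # _MergedInferenceStep zips block_kwargs with inference_infos and optional_prompts,
    # so all three sequences must have the same length (the assert above enforces this).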
+ 21 - 21
src/petals/server/block_functions.py

@@ -52,7 +52,7 @@ async def run_rpc_forward(
     """
     requires_grad = any(tensor.requires_grad for tensor in flat_tensors)
     flat_tensors = tuple(tensor.detach() for tensor in flat_tensors)
-    (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
+    (hidden_states, prompts), block_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure)
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
@@ -64,7 +64,7 @@ async def run_rpc_forward(
         prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

     # Run a chain of requested backends
-    for backend, prompt, kwargs in zip(requested_backends, prompts, backend_kwargs):
+    for backend, prompt, kwargs in zip(requested_backends, prompts, block_kwargs):
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
@@ -97,7 +97,7 @@ async def run_rpc_backward(
 ) -> Tuple[Sequence[torch.Tensor], Any]:
     """A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests"""
     assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
-    ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs(
+    ((grad_outputs,), hidden_states, prompts), block_kwargs = _check_inputs(
         requested_backends, flat_tensors, args_structure
     )
     input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad
@@ -115,7 +115,7 @@ async def run_rpc_backward(
     # Run a forward chain to collect intermediate inputs
     # Note that we do not forward for the last module since we do not need its output
     inter_inputs = []
-    for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], backend_kwargs):
+    for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], block_kwargs):
         assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
             hidden_states[:, : prompt.shape[1]] += prompt
@@ -135,11 +135,11 @@ async def run_rpc_backward(

     assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
     grad_prompts_reversed = []
-    grad_backend_kwargs_reversed = []
+    grad_block_kwargs_reversed = []

     # Run a chain of requested backends
     for hidden_states, prompt, backend, kwargs in reversed(
-        list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))
+        list(zip(inter_inputs, prompts, requested_backends, block_kwargs))
     ):
         assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools"
         hidden_states = hidden_states.detach().requires_grad_(True)
@@ -152,11 +152,11 @@ async def run_rpc_backward(
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt) and prompts_requires_grad:
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
-        grad_backend_kwargs_reversed.append(grad_kwargs)
+        grad_block_kwargs_reversed.append(grad_kwargs)

     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
     grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]
-    return pack_args_kwargs((grad_args, list(reversed(grad_backend_kwargs_reversed))))
+    return pack_args_kwargs((grad_args, list(reversed(grad_block_kwargs_reversed))))


 async def iterate_rpc_inference(
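Aside (not part of the commit): the backward loop above visits blocks in reverse, appending each block's kwargs gradients, and the final reversed() restores the original block order before packing. A tiny standalone illustration of that pattern, using placeholder strings instead of real gradient dicts:

    # Simplified sketch of the reverse-accumulate-then-flip pattern used above.
    grad_block_kwargs_reversed = []
    for grad_kwargs in ("grads_block_2", "grads_block_1", "grads_block_0"):  # backward visits the last block first
        grad_block_kwargs_reversed.append(grad_kwargs)
    grad_block_kwargs = list(reversed(grad_block_kwargs_reversed))
    assert grad_block_kwargs == ["grads_block_0", "grads_block_1", "grads_block_2"]  # original block order restored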
@@ -179,7 +179,7 @@ async def iterate_rpc_inference(

     async for request, step_metadata in input_iterator:
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
-        (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(
+        (hidden_states, prompts, hypo_ids), block_kwargs = _check_inputs(
             requested_backends, flat_tensors, args_structure
         )
         batch_size, length_increment, _ = hidden_states.shape
@@ -230,13 +230,13 @@ async def iterate_rpc_inference(
                     hypo_ids,
                     inference_infos,
                     *prompts,
-                    backend_kwargs=backend_kwargs,
+                    block_kwargs=block_kwargs,
                     priority=priority,
                     size=num_tokens,
                 )
             else:
                 for backend, uid, handles, prompt, kwargs in zip(
-                    requested_backends, requested_uids, cache_handles, prompts, backend_kwargs
+                    requested_backends, requested_uids, cache_handles, prompts, block_kwargs
                 ):
                     inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),)
                     (hidden_states,) = await backend.inference_pool.submit_task(
@@ -244,7 +244,7 @@ async def iterate_rpc_inference(
                         hypo_ids,
                         inference_infos,
                         prompt,
-                        backend_kwargs=(kwargs,),
+                        block_kwargs=(kwargs,),
                         priority=priority,
                         size=num_tokens,
                     )
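Aside (not part of the commit): the two branches above differ only in how the per-block kwargs reach the task pool. The merged path hands the whole block_kwargs sequence to a single task, while the fallback path submits one task per block and wraps that block's dict in a one-element tuple so the inference-step signature stays the same. A hypothetical, simplified sketch of that dispatch choice:

    # Illustrative only; "merged" stands in for the condition used in the real code.
    def plan_submissions(block_kwargs, merged):
        if merged:
            return [block_kwargs]                           # one task carries every block's kwargs
        return [(kwargs,) for kwargs in block_kwargs]       # one task per block, single-element tuple

    assert plan_submissions(({"a": 1}, {"b": 2}), merged=True) == [({"a": 1}, {"b": 2})]
    assert plan_submissions(({"a": 1}, {"b": 2}), merged=False) == [({"a": 1},), ({"b": 2},)]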
@@ -269,19 +269,19 @@ def _check_inputs(
             hidden_states, grad_outputs, prompts = flat_tensors
             flat_tensors = grad_outputs, hidden_states, prompts
     if args_structure is not None:
-        args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
+        args, *block_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
     else:
-        args, *backend_kwargs = flat_tensors, {}  # backward compatibility for grad structure, remove at 2.2
+        args, *block_kwargs = flat_tensors, {}  # backward compatibility for grad structure, remove at 2.2

-    if len(backend_kwargs) not in (1, len(requested_backends)):
+    if len(block_kwargs) not in (1, len(requested_backends)):
         raise RuntimeError(
             f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts "
-            f"(one for each block). Found {len(backend_kwargs)} instead."
+            f"(one for each block). Found {len(block_kwargs)} instead."
         )
-    if len(backend_kwargs) == 1:
-        backend_kwargs = backend_kwargs * len(requested_backends)
-    assert len(backend_kwargs) == len(requested_backends)
-    for i, kwargs in enumerate(backend_kwargs):
+    if len(block_kwargs) == 1:
+        block_kwargs = block_kwargs * len(requested_backends)
+    assert len(block_kwargs) == len(requested_backends)
+    for i, kwargs in enumerate(block_kwargs):
         if not isinstance(kwargs, dict):
             raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}")
-    return args, backend_kwargs
+    return args, block_kwargs
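Aside (not part of the commit): _check_inputs above lets a client send either a single kwargs dict, which is then reused for every requested block, or exactly one dict per block. A standalone, simplified restatement of that broadcast rule:

    # Illustrative helper mirroring the check-and-broadcast logic in _check_inputs.
    def broadcast_block_kwargs(block_kwargs, num_blocks):
        if len(block_kwargs) not in (1, num_blocks):
            raise RuntimeError(f"expected 1 or {num_blocks} kwargs dicts, got {len(block_kwargs)}")
        if len(block_kwargs) == 1:
            block_kwargs = block_kwargs * num_blocks
        return block_kwargs

    assert broadcast_block_kwargs([{}], 3) == [{}, {}, {}]        # one dict is broadcast to all blocks
    assert len(broadcast_block_kwargs([{}, {}, {}], 3)) == 3      # per-block dicts pass through unchanged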