@@ -13,9 +13,10 @@ import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.misc import DUMMY_KEY_PAST
 
 logger = get_logger(__name__)
 
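For context, the two newly imported helpers presumably behave roughly as sketched below; the exact shapes and signatures are assumptions inferred from how they are used in the second hunk, not definitions taken from this diff.

```python
import torch

# Assumed placeholder: an empty past-key/value tensor with zero sequence length,
# so `layer_past` can always be a real tuple instead of None on the first step.
DUMMY_KEY_PAST = torch.empty((1, 0, 0))


# Assumed factory: builds a single transformer block from a model config,
# replacing the old direct `config.block_class(config)` construction.
def get_model_block(config):
    return config.block_class(config)
```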
@@ -201,18 +202,25 @@ def measure_compute_rps(
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
-        block = config.block_class(config).to(dtype)
+        block = get_model_block(config)
+        block = block.to(dtype)
         block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-        cache = None
+        cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
-        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+
+        # Skip the 1st step to exclude the initialization time
+        def step(cache_):
+            outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
+            return outputs[1] if inference else None
+
+        cache = step(cache)
         synchronize(device)
 
         start_time = time.perf_counter()
         for _ in range(n_steps):
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
+            cache = step(cache)
         synchronize(device)
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
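The benchmarking pattern introduced above (one warm-up step excluded from timing, a device synchronization, a fixed number of timed steps, then tokens per second from the elapsed wall time) can be reproduced outside Petals. Below is a minimal self-contained sketch of that pattern; the `TransformerEncoderLayer` stand-in, the tensor shapes, and the `synchronize` helper are assumptions chosen for illustration and are not part of this change.

```python
import time

import torch


def synchronize(device: torch.device) -> None:
    # Stand-in for the synchronize() helper used in the diff: wait for queued
    # kernels so time.perf_counter() measures completed compute, not launch time.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def measure_rps(device: torch.device = torch.device("cpu"), dtype=torch.float32,
                n_tokens: int = 16, n_steps: int = 50) -> float:
    # A plain encoder layer stands in for the model block being benchmarked.
    layer = torch.nn.TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True).to(device, dtype).eval()
    dummy_input = torch.randn(1, n_tokens, 256, device=device, dtype=dtype)

    with torch.inference_mode():
        layer(dummy_input)  # warm-up step, excluded from timing
        synchronize(device)

        start_time = time.perf_counter()
        for _ in range(n_steps):
            layer(dummy_input)
        synchronize(device)
        elapsed = time.perf_counter() - start_time

    return n_steps * n_tokens / elapsed  # tokens processed per second


if __name__ == "__main__":
    print(f"{measure_rps():.1f} tokens/s")
```

Keeping the warm-up call outside the timed loop mirrors the intent of the diff: the first forward pass typically includes one-off initialization cost that would otherwise skew `device_rps`.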