Your Name, 2 years ago
parent commit 7dc1aa5151
1 changed file with 4 additions and 0 deletions

+ 4 - 0
src/petals/server/backend.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import time
 from collections import Counter
 from itertools import chain
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
@@ -188,8 +189,11 @@ class _MergedInferenceStep:
         assert len(inference_infos) == len(
             optional_prompts
         ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts"
+        t0 = time.perf_counter()
         for inference_info, optional_prompt in zip(inference_infos, optional_prompts):
             if optional_prompt is not None:
                 hidden_states[:, : optional_prompt.shape[1]] += optional_prompt
             (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info)
+        torch.cuda.synchronize()
+        print(f"INFERENCE TIME: {time.perf_counter() - t0:.5f} s")
         return (hidden_states,)
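
The torch.cuda.synchronize() call before the second clock read is what makes this measurement meaningful: CUDA kernels launch asynchronously, so without it time.perf_counter() would stop as soon as the last kernel was queued rather than when it finished. Note that the patch starts the timer without a matching synchronize, so any GPU work still in flight when the step begins is folded into the measurement. Below is a minimal sketch of the same pattern as a standalone helper; timed_forward and fn are hypothetical names standing in for the merged inference step, not part of Petals:

import time

import torch

def timed_forward(fn, *args):
    # Drain any kernels already queued so t0 marks a clean start.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = fn(*args)
    # Wait for the kernels fn launched before reading the clock,
    # otherwise only the (cheap) kernel-launch overhead is measured.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"INFERENCE TIME: {time.perf_counter() - t0:.5f} s")
    return out

For timing that does not block the host, paired torch.cuda.Event(enable_timing=True) records are the usual alternative, but for a quick debug print the synchronize-then-perf_counter pattern above is standard.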