Переглянути джерело

Fix output shape when resuming generation (#211)

Before this PR, `model.generate()` returned one excess token when resuming generation with an existing (the last token of the previous session, `session.last_token_id`). This is an unexpected behavior not convenient for the downstream apps, so this PR changes it until it's too late.
Alexander Borzunov 2 роки тому
батько
коміт
6ba63c6cc8
3 змінених файлів з 10 додано та 5 видалено
  1. 1 1
      setup.cfg
  2. 7 4
      src/petals/client/remote_generation.py
  3. 2 0
      src/petals/server/throughput.py

+ 1 - 1
setup.cfg

@@ -42,7 +42,7 @@ install_requires =
     humanfriendly
     async-timeout>=4.0.2
     cpufeature>=0.2.0
-    packaging>=23.0
+    packaging>=20.9
 
 [options.extras_require]
 dev =

+ 7 - 4
src/petals/client/remote_generation.py

@@ -104,17 +104,18 @@ class RemoteGenerationMixin:
         elif max_length is None and max_new_tokens is not None:
             max_length = prefix_length + max_new_tokens
 
-        if num_beams > 1 and session is not None:
+        resuming_session = session is not None and session.last_token_id is not None
+        if num_beams > 1 and resuming_session:
             raise NotImplementedError(
-                "Reusing inference session in .generate() along with beam search is not supported yet"
+                "Resuming inference session in .generate() along with beam search is not supported yet"
             )
 
         if inputs is not None:
             assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
-            if session is not None and session.last_token_id is not None:
+            if resuming_session:
                 inputs = torch.cat([session.last_token_id, inputs], dim=1)
         else:
-            if session is not None and session.last_token_id is not None:
+            if resuming_session:
                 inputs = session.last_token_id
             else:
                 assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
@@ -207,6 +208,8 @@ class RemoteGenerationMixin:
 
         outputs = torch.cat(outputs, dim=-1)
 
+        if resuming_session:
+            outputs = outputs[:, 1:]
         if num_beams > 1:
             pre_return_idx = [
                 torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)

+ 2 - 0
src/petals/server/throughput.py

@@ -123,6 +123,8 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
 
     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
     network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
+    if network_rps == 0:
+        raise ValueError("speedtest has returned network_rps == 0")
 
     logger.info(
         f"Network throughput: "