|
@@ -104,17 +104,18 @@ class RemoteGenerationMixin:
|
|
elif max_length is None and max_new_tokens is not None:
|
|
elif max_length is None and max_new_tokens is not None:
|
|
max_length = prefix_length + max_new_tokens
|
|
max_length = prefix_length + max_new_tokens
|
|
|
|
|
|
- if num_beams > 1 and session is not None:
|
|
|
|
|
|
+ resuming_session = session is not None and session.last_token_id is not None
|
|
|
|
+ if num_beams > 1 and resuming_session:
|
|
raise NotImplementedError(
|
|
raise NotImplementedError(
|
|
- "Reusing inference session in .generate() along with beam search is not supported yet"
|
|
|
|
|
|
+ "Resuming inference session in .generate() along with beam search is not supported yet"
|
|
)
|
|
)
|
|
|
|
|
|
if inputs is not None:
|
|
if inputs is not None:
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
|
|
- if session is not None and session.last_token_id is not None:
|
|
|
|
|
|
+ if resuming_session:
|
|
inputs = torch.cat([session.last_token_id, inputs], dim=1)
|
|
inputs = torch.cat([session.last_token_id, inputs], dim=1)
|
|
else:
|
|
else:
|
|
- if session is not None and session.last_token_id is not None:
|
|
|
|
|
|
+ if resuming_session:
|
|
inputs = session.last_token_id
|
|
inputs = session.last_token_id
|
|
else:
|
|
else:
|
|
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
|
|
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
|
|
@@ -207,6 +208,8 @@ class RemoteGenerationMixin:
|
|
|
|
|
|
outputs = torch.cat(outputs, dim=-1)
|
|
outputs = torch.cat(outputs, dim=-1)
|
|
|
|
|
|
|
|
+ if resuming_session:
|
|
|
|
+ outputs = outputs[:, 1:]
|
|
if num_beams > 1:
|
|
if num_beams > 1:
|
|
pre_return_idx = [
|
|
pre_return_idx = [
|
|
torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
|
|
torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
|