|
@@ -29,7 +29,6 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
|
|
|
|
|
|
sess.position = 2
|
|
|
secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
|
|
|
- secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
|
|
|
result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
|
|
|
|
|
|
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
|