tokenize_wikitext103.py

#!/usr/bin/env python
""" This script builds a pre-tokenized compressed representation of WikiText-103 using huggingface/datasets """
import random
from functools import partial

import nltk
from datasets import load_dataset
from transformers import AlbertTokenizerFast
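
# Per-instance fields produced below: the standard tokenizer outputs plus
# `sentence_order_label` for ALBERT's sentence-order prediction objective.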
COLUMN_NAMES = ("attention_mask", "input_ids", "sentence_order_label", "special_tokens_mask", "token_type_ids")


def create_instances_from_document(tokenizer, document, max_seq_length):
    """
    Creates training instances from a single document.
    Reuses code from the original ALBERT implementation (Google AI, 2018)
    https://github.com/google-research/albert/blob/master/create_pretraining_data.py#L267
    """
    # We DON'T just concatenate all of the tokens from a document into a long
    # sequence and choose an arbitrary split point because this would make the
    # next sentence prediction task too easy. Instead, we split the input into
    # segments "A" and "B" based on the actual "sentences" provided by the user
    # input.
    instances = []
    current_chunk = []
    current_length = 0

    segmented_sents = list(nltk.sent_tokenize(document))

    for i, sent in enumerate(segmented_sents):
        current_chunk.append(sent)
        current_length += len(tokenizer.tokenize(sent))
        if i == len(segmented_sents) - 1 or current_length >= max_seq_length:
            if len(current_chunk) > 1:
                # `a_end` is how many segments from `current_chunk` go into the `A`
                # (first) sentence.
                a_end = random.randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.append(current_chunk[j])

                tokens_b = []
                for j in range(a_end, len(current_chunk)):
                    tokens_b.append(current_chunk[j])

                if random.random() < 0.5:
                    # Random next
                    is_random_next = True
                    # Note(mingdachen): in this case, we just swap tokens_a and tokens_b
                    tokens_a, tokens_b = tokens_b, tokens_a
                else:
                    # Actual next
                    is_random_next = False

                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                instance = tokenizer(
                    " ".join(tokens_a),
                    " ".join(tokens_b),
                    truncation="longest_first",
                    max_length=max_seq_length,
                    # We use this option because DataCollatorForLanguageModeling
                    # is more efficient when it receives the `special_tokens_mask`.
                    return_special_tokens_mask=True,
                )
                assert len(instance["input_ids"]) <= max_seq_length
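                # ALBERT's sentence-order prediction (SOP) target:
                # 1 = segments were swapped, 0 = segments are in their original order.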
                instance["sentence_order_label"] = 1 if is_random_next else 0
                instances.append(instance)

            current_chunk = []
            current_length = 0

    return instances


def tokenize_function(tokenizer, examples):
    # Remove empty texts
    texts = (text for text in examples["text"] if len(text) > 0 and not text.isspace())

    new_examples = {col: [] for col in COLUMN_NAMES}

    for text in texts:
        instances = create_instances_from_document(tokenizer, text, max_seq_length=512)
        for instance in instances:
            for key, value in instance.items():
                new_examples[key].append(value)

    return new_examples


if __name__ == "__main__":
    random.seed(0)
    nltk.download("punkt")
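    # Note: newer NLTK releases (3.9+) ship the sentence tokenizer as the
    # "punkt_tab" resource, so that download may be needed there as well.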
    tokenizer = AlbertTokenizerFast.from_pretrained("albert-large-v2")
    wikitext = load_dataset("wikitext", "wikitext-103-v1", cache_dir="./data/cache")
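
    # `partial` binds the tokenizer so that `map` only passes the batch dict;
    # with batched=True each call receives many raw lines and may return a
    # different number of tokenized instances than it was given.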
    tokenized_datasets = wikitext.map(
        partial(tokenize_function, tokenizer),
        batched=True,
        num_proc=8,
        remove_columns=["text"],
    )

    tokenized_datasets.save_to_disk("./data/albert_tokenized_wikitext")
    tokenizer.save_pretrained("./data/tokenizer")
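
# A minimal sketch of how the saved artifacts might be consumed later, assuming
# the output paths above; DataCollatorForLanguageModeling handles MLM masking,
# while `sentence_order_label` is already stored per example:
#
#   from datasets import load_from_disk
#   from transformers import AlbertTokenizerFast, DataCollatorForLanguageModeling
#
#   tokenizer = AlbertTokenizerFast.from_pretrained("./data/tokenizer")
#   tokenized = load_from_disk("./data/albert_tokenized_wikitext")
#   collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)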