tokenize_wikitext103.py

#!/usr/bin/env python
""" This script builds a pre-tokenized compressed representation of wikitext103 using huggingface/datasets """
import random
from functools import partial
from multiprocessing import cpu_count

import nltk
from datasets import load_dataset
from transformers import AlbertTokenizerFast

COLUMN_NAMES = ('attention_mask', 'input_ids', 'sentence_order_label', 'special_tokens_mask', 'token_type_ids')


def create_instances_from_document(tokenizer, document, max_seq_length):
    """Creates `TrainingInstance`s for a single document."""
    # We DON'T just concatenate all of the tokens from a document into a long
    # sequence and choose an arbitrary split point because this would make the
    # next sentence prediction task too easy. Instead, we split the input into
    # segments "A" and "B" based on the actual "sentences" provided by the user
    # input.
    instances = []
    current_chunk = []
    current_length = 0

    segmented_sents = list(nltk.sent_tokenize(document))

    for i, sent in enumerate(segmented_sents):
        current_chunk.append(sent)
        current_length += len(tokenizer.tokenize(sent))
        if i == len(segmented_sents) - 1 or current_length >= max_seq_length:
            if len(current_chunk) > 1:
                # `a_end` is how many segments from `current_chunk` go into the `A`
                # (first) sentence.
                a_end = random.randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.append(current_chunk[j])

                tokens_b = []
                for j in range(a_end, len(current_chunk)):
                    tokens_b.append(current_chunk[j])

                if random.random() < 0.5:
                    # Random next
                    is_random_next = True
                    # Note(mingdachen): in this case, we just swap tokens_a and tokens_b
                    tokens_a, tokens_b = tokens_b, tokens_a
                else:
                    # Actual next
                    is_random_next = False

                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                instance = tokenizer(
                    ' '.join(tokens_a),
                    ' '.join(tokens_b),
                    truncation='longest_first',
                    max_length=max_seq_length,
                    # We use this option because DataCollatorForLanguageModeling
                    # is more efficient when it receives the `special_tokens_mask`.
                    return_special_tokens_mask=True,
                )

                assert len(instance['input_ids']) <= max_seq_length
                instance["sentence_order_label"] = 1 if is_random_next else 0
                instances.append(instance)

            current_chunk = []
            current_length = 0

    return instances
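
# Illustrative example (a sketch only): a short two-sentence document typically yields a
# single instance, and each instance carries the keys listed in COLUMN_NAMES, e.g.
#
#   instances = create_instances_from_document(tokenizer, "First sentence. Second sentence.", max_seq_length=512)
#   # instances[0] -> {'input_ids': [...], 'token_type_ids': [...], 'attention_mask': [...],
#   #                  'special_tokens_mask': [...], 'sentence_order_label': 0 or 1}
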
def tokenize_function(tokenizer, examples):
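    """Convert a batch of raw text rows into ALBERT training features: one flat list per column in COLUMN_NAMES."""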
    # Remove empty texts
    texts = (text for text in examples["text"] if len(text) > 0 and not text.isspace())

    new_examples = {col: [] for col in COLUMN_NAMES}

    for text in texts:
        instances = create_instances_from_document(tokenizer, text, max_seq_length=512)
        for instance in instances:
            for key, value in instance.items():
                new_examples[key].append(value)

    return new_examples


if __name__ == '__main__':
    random.seed(0)
    nltk.download('punkt')

    tokenizer = AlbertTokenizerFast.from_pretrained('albert-large-v2')
    wikitext = load_dataset('wikitext', 'wikitext-103-v1', cache_dir='./data/cache')

    tokenized_datasets = wikitext.map(
        partial(tokenize_function, tokenizer),
        batched=True,
        num_proc=8,
        remove_columns=["text"],
    )

    tokenized_datasets.save_to_disk('./data/albert_tokenized_wikitext')
    tokenizer.save_pretrained('./data/tokenizer')
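
# Minimal sketch of downstream usage (the actual training setup lives outside this script;
# the masking probability below is the usual 15% default and is an assumption): the saved
# dataset and tokenizer can later be re-loaded and paired with DataCollatorForLanguageModeling,
# which is why `special_tokens_mask` is kept above. Roughly:
#
#   from datasets import load_from_disk
#   from transformers import AlbertTokenizerFast, DataCollatorForLanguageModeling
#
#   tokenizer = AlbertTokenizerFast.from_pretrained('./data/tokenizer')
#   tokenized_datasets = load_from_disk('./data/albert_tokenized_wikitext')
#   collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
#   # `collator` builds masked-LM batches from tokenized_datasets['train'], while the
#   # precomputed `sentence_order_label` column serves ALBERT's sentence-order objective.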