# data.py
  1. from typing import Optional
  2. import hivemind
  3. import numpy as np
  4. from datasets import load_dataset
# Module-level logger, created via hivemind's logger factory for this module's name.
logger = hivemind.get_logger(__name__)
  6. def make_dataset(
  7. tokenizer,
  8. *,
  9. shuffle_buffer_size: int = 10 ** 4,
  10. shuffle_seed: Optional[int],
  11. preprocessing_batch_size: int = 256,
  12. max_sequence_length: int,
  13. ):
  14. ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
  15. ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
  16. ds = ds.map(lambda item: dict(
  17. tokenizer(
  18. [caption if caption is not None else '' for caption in item['caption']],
  19. truncation=True, max_length=max_sequence_length,
  20. ),
  21. image=np.stack([np.frombuffer(encoded, np.int16).astype(np.int64) for encoded in item['code']]),
  22. ), batched=True, batch_size=preprocessing_batch_size)
  23. ds = ds.with_format('torch')
  24. return ds