data.py

from typing import Optional

import hivemind
import numpy as np
from datasets import load_dataset

logger = hivemind.get_logger(__name__)


def make_dataset(
    tokenizer,
    *,
    shuffle_buffer_size: int = 10 ** 4,
    shuffle_seed: Optional[int],
    preprocessing_batch_size: int = 256,
    max_sequence_length: int,
):
    # Stream the LAION-100M dataset with precomputed VQGAN f8 image codes
    ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
    ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
    # Tokenize captions and unpack each int16-encoded image code buffer into int64 token ids
    ds = ds.map(lambda item: dict(
        tokenizer(item['caption'], truncation=True, max_length=max_sequence_length),
        image=np.stack([np.frombuffer(encoded, np.int16).astype(np.int64) for encoded in item['code']]),
    ), batched=True, batch_size=preprocessing_batch_size)
    ds = ds.with_format('torch')
    return ds
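
A minimal usage sketch (not part of data.py), assuming a Hugging Face tokenizer; the 't5-small' checkpoint and the keyword values below are illustrative only:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')  # hypothetical tokenizer choice
ds = make_dataset(tokenizer, shuffle_seed=0, max_sequence_length=64)
sample = next(iter(ds))  # dict with tokenized caption fields plus the 'image' code tensor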