data.py

import itertools
from typing import Optional

import hivemind
import numpy as np
from datasets import load_dataset

logger = hivemind.get_logger(__name__)

def preprocess_batch(batch, tokenizer, max_sequence_length: int):
    # Keep only captions that are non-empty (>= 3 chars), marked unlikely to be NSFW,
    # and whose images have valid dimensions with an aspect ratio of at most 2:1
    mask = [
        (
            caption is not None and len(caption) >= 3 and
            nsfw == 'UNLIKELY' and
            orig_width > 0 and orig_height > 0 and
            max(orig_height / orig_width, orig_width / orig_height) <= 2
        )
        for caption, nsfw, orig_width, orig_height in
        zip(batch['caption'], batch['NSFW'], batch['original_width'], batch['original_height'])
    ]
    logger.debug(f'{np.mean(mask) * 100:.1f}% of examples left after filtering')

    if any(mask):
        result = tokenizer(list(itertools.compress(batch['caption'], mask)),
                           add_special_tokens=False, max_length=max_sequence_length, truncation=True)
    else:
        # This branch is necessary because tokenizer([]) raises IndexError
        result = {'input_ids': [], 'attention_mask': []}

    # Decode the image codes stored as int16 bytes into int64 token arrays
    result['image'] = [np.frombuffer(encoded, np.int16).astype(np.int64)
                       for encoded in itertools.compress(batch['code'], mask)]
    return result
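
# Example of what preprocess_batch does on a hand-made batch (illustrative sketch, not
# part of the original module; assumes any Hugging Face tokenizer, e.g. t5-small):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained('t5-small')
#   batch = {
#       'caption': ['a photo of a cat', 'xy'],  # second caption is shorter than 3 chars -> filtered out
#       'NSFW': ['UNLIKELY', 'UNLIKELY'],
#       'original_width': [640, 640],
#       'original_height': [480, 480],
#       'code': [np.arange(16, dtype=np.int16).tobytes()] * 2,
#   }
#   out = preprocess_batch(batch, tok, max_sequence_length=32)
#   # -> out['input_ids'] and out['attention_mask'] hold one tokenized caption,
#   #    out['image'] holds one decoded int64 array of length 16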

def make_dataset(
    tokenizer,
    *,
    shuffle_buffer_size: int = 8192,
    shuffle_seed: Optional[int],
    preprocessing_batch_size: int = 256,
    max_sequence_length: int,
):
    # Stream the LAION subset with precomputed VQGAN codes, shuffle through a bounded
    # buffer, and filter/tokenize on the fly in batches
    ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
    ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
    ds = ds.map(lambda batch: preprocess_batch(batch, tokenizer, max_sequence_length),
                batched=True, batch_size=preprocessing_batch_size)
    ds = ds.with_format('torch')
    return ds
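

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: the tokenizer choice and
    # hyperparameters below are illustrative assumptions, and the loop only peeks at a
    # couple of streamed examples.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('t5-small')
    dataset = make_dataset(tokenizer, shuffle_seed=42, max_sequence_length=256)

    for index, example in enumerate(dataset):
        # Each example carries the tokenized caption fields plus the decoded image codes
        print({key: getattr(value, 'shape', value) for key, value in example.items()})
        if index >= 1:
            break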