@@ -1,3 +1,4 @@
+import itertools
 from typing import Optional
 
 import hivemind
@@ -7,6 +8,29 @@ from datasets import load_dataset
 logger = hivemind.get_logger(__name__)
 
 
+def preprocess_batch(batch, tokenizer, max_sequence_length: int):
+    mask = [
+        (
+            caption is not None and len(caption) >= 3 and
+            nsfw == 'UNLIKELY' and
+            orig_width > 0 and orig_height > 0 and
+            max(orig_height / orig_width, orig_width / orig_height) <= 2
+        ) for caption, nsfw, orig_width, orig_height in
+        zip(batch['caption'], batch['NSFW'], batch['original_width'], batch['original_height'])
+    ]
+    logger.debug(f'{np.mean(mask) * 100:.1f}% of examples left after filtering')
+
+    if any(mask):
+        result = tokenizer(list(itertools.compress(batch['caption'], mask)),
+                           truncation=True, max_length=max_sequence_length)
+    else:
+        # This branch is necessary because tokenizer([]) raises IndexError
+        result = {'input_ids': [], 'attention_mask': []}
+    result['image'] = [np.frombuffer(encoded, np.int16).astype(np.int64)
+                       for encoded in itertools.compress(batch['code'], mask)]
+    return result
+
+
 def make_dataset(
     tokenizer,
     *,
@@ -17,12 +41,7 @@ def make_dataset(
 ):
     ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
     ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
-    ds = ds.map(lambda item: dict(
-        tokenizer(
-            [caption if caption is not None else '' for caption in item['caption']],
-            truncation=True, max_length=max_sequence_length,
-        ),
-        image=np.stack([np.frombuffer(encoded, np.int16).astype(np.int64) for encoded in item['code']]),
-    ), batched=True, batch_size=preprocessing_batch_size)
+    ds = ds.map(lambda batch: preprocess_batch(batch, tokenizer, max_sequence_length),
+                batched=True, batch_size=preprocessing_batch_size)
     ds = ds.with_format('torch')
     return ds
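For reference, the new preprocess_batch can be exercised on its own, outside of make_dataset. The sketch below is illustrative only: the module name data, the gpt2 tokenizer, and the toy batch values are assumptions; only the column names (caption, NSFW, original_width, original_height, code) and the function signature come from the diff above.

import numpy as np
from transformers import AutoTokenizer

from data import preprocess_batch  # module name is an assumption

# Any Hugging Face tokenizer with the usual __call__(texts, truncation=..., max_length=...) API works here.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Toy batch mimicking the LAION columns used above; 'code' holds VQGAN tokens serialized as int16 bytes.
batch = {
    'caption': ['a photo of a cat', None, 'a dog on a skateboard'],
    'NSFW': ['UNLIKELY', 'UNLIKELY', 'UNSURE'],
    'original_width': [640, 480, 512],
    'original_height': [480, 640, 512],
    'code': [np.arange(1024, dtype=np.int16).tobytes() for _ in range(3)],
}

result = preprocess_batch(batch, tokenizer, max_sequence_length=64)
# Only the first example survives the filter: the second has no caption, the third is not marked 'UNLIKELY'.
print(len(result['input_ids']), len(result['image']))  # -> 1 1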