瀏覽代碼

Fix and improve preprocessing

Aleksandr Borzunov 3 年之前
父節點
當前提交
17da8639c2
共有 1 個文件被更改,包括 26 次插入7 次删除
  1. 26 7
      data.py

+ 26 - 7
data.py

@@ -1,3 +1,4 @@
+import itertools
 from typing import Optional
 
 import hivemind
@@ -7,6 +8,29 @@ from datasets import load_dataset
 logger = hivemind.get_logger(__name__)
 
 
+def preprocess_batch(batch, tokenizer, max_sequence_length: int):
+    mask = [
+        (
+            caption is not None and len(caption) >= 3 and
+            nsfw == 'UNLIKELY' and
+            orig_width > 0 and orig_height > 0 and
+            max(orig_height / orig_width, orig_width / orig_height) <= 2
+        ) for caption, nsfw, orig_width, orig_height in
+        zip(batch['caption'], batch['NSFW'], batch['original_width'], batch['original_height'])
+    ]
+    logger.debug(f'{np.mean(mask) * 100:.1f}% of examples left after filtering')
+
+    if any(mask):
+        result = tokenizer(list(itertools.compress(batch['caption'], mask)),
+                           truncation=True, max_length=max_sequence_length)
+    else:
+        # This branch is necessary because tokenizer([]) raises IndexError
+        result = {'input_ids': [], 'attention_mask': []}
+    result['image'] = [np.frombuffer(encoded, np.int16).astype(np.int64)
+                       for encoded in itertools.compress(batch['code'], mask)]
+    return result
+
+
 def make_dataset(
     tokenizer,
     *,
@@ -17,12 +41,7 @@ def make_dataset(
 ):
     ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
     ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
-    ds = ds.map(lambda item: dict(
-        tokenizer(
-            [caption if caption is not None else '' for caption in item['caption']],
-            truncation=True, max_length=max_sequence_length,
-        ),
-        image=np.stack([np.frombuffer(encoded, np.int16).astype(np.int64) for encoded in item['code']]),
-    ), batched=True, batch_size=preprocessing_batch_size)
+    ds = ds.map(lambda batch: preprocess_batch(batch, tokenizer, max_sequence_length),
+                batched=True, batch_size=preprocessing_batch_size)
     ds = ds.with_format('torch')
     return ds