Aleksandr Borzunov преди 3 години
родител
ревизия
31215c85d3
променени са 1 файла, в които са добавени 4 реда и са изтрити 1 реда
  1. 4 1
      data.py

+ 4 - 1
data.py

@@ -18,7 +18,10 @@ def make_dataset(
     ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
     ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
     ds = ds.map(lambda item: dict(
-        tokenizer(item['caption'], truncation=True, max_length=max_sequence_length),
+        tokenizer(
+            [caption if caption is not None else '' for caption in item['caption']],
+            truncation=True, max_length=max_sequence_length,
+        ),
         image=np.stack([np.frombuffer(encoded, np.int16).astype(np.int64) for encoded in item['code']]),
     ), batched=True, batch_size=preprocessing_batch_size)
     ds = ds.with_format('torch')