|
@@ -18,7 +18,10 @@ def make_dataset(
|
|
|
ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
|
|
|
ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
|
|
|
ds = ds.map(lambda item: dict(
|
|
|
- tokenizer(item['caption'], truncation=True, max_length=max_sequence_length),
|
|
|
+ tokenizer(
|
|
|
+ [caption if caption is not None else '' for caption in item['caption']],
|
|
|
+ truncation=True, max_length=max_sequence_length,
|
|
|
+ ),
|
|
|
image=np.stack([np.frombuffer(encoded, np.int16).astype(np.int64) for encoded in item['code']]),
|
|
|
), batched=True, batch_size=preprocessing_batch_size)
|
|
|
ds = ds.with_format('torch')
|