run_inference.py 5.2 KB

#!/usr/bin/env python3
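"""Generate images for a list of text queries with a DALLE-pytorch model on top of a
VQGAN decoder, rank the samples for each query with CLIP, and save the images together
with their CLIP scores as one pickle file per query."""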
import argparse
import os
import pickle
from collections import OrderedDict
from datetime import datetime
from itertools import cycle, islice

import clip
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from einops import rearrange
# Note: Use dalle_pytorch >= 1.4.2 for this script (newer than in the rest of the repo)
from dalle_pytorch import DALLE
from dalle_pytorch.vae import VQGanVAE
from transformers import T5TokenizerFast
from tqdm import tqdm

# Inference only, so gradients are never needed.
torch.set_grad_enabled(False)


class VQGanParams(VQGanVAE):
    """Stub that carries only the VQGAN hyperparameters DALLE needs at construction time.

    The real VQGanVAE (with weights) is attached to the model later, in main().
    """

    def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True):
        nn.Module.__init__(self)
        self.num_layers = num_layers
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.is_gumbel = is_gumbel


class ModelWrapper(nn.Module):
    """Training-style wrapper around DALLE.

    The checkpoint is loaded into this wrapper in main(), so its keys are expected to
    carry the `model.` prefix.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, image):
        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
        return {'loss': loss}


def make_model():
    tokenizer = T5TokenizerFast.from_pretrained('t5-small')
    tokenizer.pad_token = tokenizer.eos_token

    depth = 64
    # Sparse attention schedule: mostly axial row/column attention, with a
    # convolution-like layer at the very end; layer weights are shared in groups.
    attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
    attn_types.append('conv_like')
    shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
    shared_layer_ids.append('w_conv')

    dalle = DALLE(
        vae=VQGanParams(),
        num_text_tokens=tokenizer.vocab_size,
        text_seq_len=256,
        dim=1024,
        depth=depth,
        heads=16,
        dim_head=64,
        attn_types=attn_types,
        ff_dropout=0,
        attn_dropout=0,
        shared_attn_ids=shared_layer_ids,
        shared_ff_ids=shared_layer_ids,
        rotary_emb=True,
        reversible=True,
        share_input_output_emb=True,
        optimize_for_inference=True,
    )
    model = ModelWrapper(dalle)
    return tokenizer, model


def generate(query, *, tokenizer, model,
             batch_size, n_iters, temperature, top_k, top_p):
    # Tokenize the query and pad it to the full text_seq_len (256) with token id 1
    # (T5's EOS, which was also set as the pad token above).
    input_ids = torch.tensor(tokenizer(query, add_special_tokens=False, max_length=256, truncation=True)['input_ids'])
    input_ids = F.pad(input_ids, (0, 256 - len(input_ids)), value=1)
    input_ids = input_ids.repeat(batch_size, 1)
    input_ids = input_ids.cuda()

    result = []
    for _ in tqdm(range(n_iters), desc=query, leave=False):
        output = model.model.generate_images(
            input_ids, temperature=temperature, top_k=top_k, top_p=top_p, use_cache=True)
        # (b, c, h, w) floats in [0, 1] -> list of HWC numpy arrays
        output = rearrange(output, 'b c h w -> b h w c').cpu().numpy()
        result.extend(output)
    return result
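
# Illustrative direct call of generate() (a sketch; the query and sampling values here
# are made up, and the checkpoint/VQGAN must already be loaded as done in main() below):
#
#   images = generate('a watercolor painting of a fox', tokenizer=tokenizer, model=model,
#                     batch_size=4, n_iters=2, temperature=1.0, top_k=0, top_p=0.9)
#   # -> list of 4 * 2 = 8 numpy arrays of shape (256, 256, 3) with values in [0, 1]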

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--queries', type=str, help='List of queries (*.txt, newline-separated)')
    parser.add_argument('--temperature', type=float, help='Sampling temperature')
    parser.add_argument('--top-k', type=int, default=0)
    parser.add_argument('--top-p', type=float, default=1.0)
    parser.add_argument('--model', type=str, help='DALL-E checkpoint (*.pt)')
    parser.add_argument('--vqgan', type=str, help='VQGAN checkpoint (*.ckpt)')
    parser.add_argument('--vqgan-config', type=str, help='VQGAN config (*.yaml)')
    parser.add_argument('--output-dir', type=str, help='Output directory')
    args = parser.parse_args()

    with open(args.queries) as f:
        queries = [line.rstrip() for line in f]
    queries = [item for item in queries if len(item) > 0]
    print(f'[*] Loaded {len(queries)} queries')

    tokenizer, model = make_model()
    print(f'[*] Model modification time: {datetime.fromtimestamp(os.stat(args.model).st_mtime)}')
    state_dict = torch.load(args.model)
    # The model version optimized for inference requires some renaming in state_dict
    state_dict = OrderedDict([
        (key.replace('net.fn.fn', 'net.fn.fn.fn').replace('to_qkv', 'fn.to_qkv').replace('to_out', 'fn.to_out'), value)
        for key, value in state_dict.items()])
    ok = model.load_state_dict(state_dict)
    print(f'[*] Loaded model: {ok}')

    # Swap the VQGanParams stub for the real VQGAN decoder.
    gan = VQGanVAE(args.vqgan, args.vqgan_config).cuda()
    model.model.vae = gan
    model = model.cuda()

    clip_model, clip_preprocess = clip.load("ViT-B/32", device='cuda')

    os.makedirs(args.output_dir, exist_ok=True)
    print(f'[*] Saving results to `{args.output_dir}`')
    for query in tqdm(queries):
        images = generate(query, tokenizer=tokenizer, model=model, batch_size=16, n_iters=8,
                          temperature=args.temperature, top_k=args.top_k, top_p=args.top_p)
        # Score every generated image against the query with CLIP (softmax over the
        # images for this single query) and save everything to one pickle per query.
        images_for_clip = torch.cat([
            clip_preprocess(Image.fromarray((img * 255).astype(np.uint8))).unsqueeze(0).cuda()
            for img in images])
        text = clip.tokenize([query]).cuda()
        _, logits_per_text = clip_model(images_for_clip, text)
        clip_scores = logits_per_text[0].softmax(dim=-1).cpu().numpy()
        with open(os.path.join(args.output_dir, f'{query}.pickle'), 'wb') as f:
            outputs = {'query': query, 'temperature': args.temperature, 'images': images, 'clip_scores': clip_scores}
            pickle.dump(outputs, f)


if __name__ == '__main__':
    main()
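
# Example invocation (all file names below are hypothetical; substitute your own paths):
#
#   python run_inference.py \
#       --queries queries.txt \
#       --model dalle_checkpoint.pt \
#       --vqgan vqgan.ckpt --vqgan-config vqgan_config.yaml \
#       --temperature 1.0 --top-k 0 --top-p 0.9 \
#       --output-dir outputs/
#
# Each `<query>.pickle` in the output directory can then be read back roughly like this
# (a sketch based on the `outputs` dict written above):
#
#   import pickle
#   import numpy as np
#   with open('outputs/a cat.pickle', 'rb') as f:
#       result = pickle.load(f)
#   best = result['images'][int(np.argmax(result['clip_scores']))]  # highest CLIP score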