|
@@ -77,7 +77,7 @@ def make_model():
|
|
|
|
|
|
|
|
|
|
def generate(query, *, tokenizer, model,
|
|
def generate(query, *, tokenizer, model,
|
|
- batch_size, n_iters, temperature, filter_thres):
|
|
|
|
|
|
+ batch_size, n_iters, temperature, top_k, top_p):
|
|
input_ids = torch.tensor(tokenizer(query, add_special_tokens=False, max_length=256, truncation=True)['input_ids'])
|
|
input_ids = torch.tensor(tokenizer(query, add_special_tokens=False, max_length=256, truncation=True)['input_ids'])
|
|
input_ids = F.pad(input_ids, (0, 256 - len(input_ids)), value=1)
|
|
input_ids = F.pad(input_ids, (0, 256 - len(input_ids)), value=1)
|
|
input_ids = input_ids.repeat(batch_size, 1)
|
|
input_ids = input_ids.repeat(batch_size, 1)
|
|
@@ -86,7 +86,7 @@ def generate(query, *, tokenizer, model,
|
|
result = []
|
|
result = []
|
|
for _ in tqdm(range(n_iters), desc=query, leave=False):
|
|
for _ in tqdm(range(n_iters), desc=query, leave=False):
|
|
output = model.model.generate_images(
|
|
output = model.model.generate_images(
|
|
- input_ids, temperature=temperature, filter_thres=filter_thres, use_cache=True)
|
|
|
|
|
|
+ input_ids, temperature=temperature, top_k=top_k, top_p=top_p, use_cache=True)
|
|
output = rearrange(output, 'b c h w -> b h w c').cpu().numpy()
|
|
output = rearrange(output, 'b c h w -> b h w c').cpu().numpy()
|
|
result.extend(output)
|
|
result.extend(output)
|
|
return result
|
|
return result
|
|
@@ -96,7 +96,8 @@ def main():
|
|
parser = argparse.ArgumentParser()
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--queries', type=str, help='List of queries (*.txt, newline-separated)')
|
|
parser.add_argument('--queries', type=str, help='List of queries (*.txt, newline-separated)')
|
|
parser.add_argument('--temperature', type=float, help='Sampling temperature')
|
|
parser.add_argument('--temperature', type=float, help='Sampling temperature')
|
|
- parser.add_argument('--filter-thres', type=float, help='Sampling filtering threshold')
|
|
|
|
|
|
+ parser.add_argument('--top-k', type=int, default=0)
|
|
|
|
+ parser.add_argument('--top-p', type=float, default=1.0)
|
|
parser.add_argument('--model', type=str, help='DALL-E checkpoint (*.pt)')
|
|
parser.add_argument('--model', type=str, help='DALL-E checkpoint (*.pt)')
|
|
parser.add_argument('--vqgan', type=str, help='VQGAN checkpoint (*.ckpt)')
|
|
parser.add_argument('--vqgan', type=str, help='VQGAN checkpoint (*.ckpt)')
|
|
parser.add_argument('--vqgan-config', type=str, help='VQGAN config (*.yaml)')
|
|
parser.add_argument('--vqgan-config', type=str, help='VQGAN config (*.yaml)')
|
|
@@ -122,14 +123,14 @@ def main():
|
|
model.model.vae = gan
|
|
model.model.vae = gan
|
|
model = model.cuda()
|
|
model = model.cuda()
|
|
|
|
|
|
- clip_model, clip_preprocess = clip.load("ViT-L/14", device='cuda')
|
|
|
|
|
|
+ clip_model, clip_preprocess = clip.load("ViT-B/32", device='cuda')
|
|
|
|
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
|
print(f'[*] Saving results to `{args.output_dir}`')
|
|
print(f'[*] Saving results to `{args.output_dir}`')
|
|
|
|
|
|
for query in tqdm(queries):
|
|
for query in tqdm(queries):
|
|
images = generate(query, tokenizer=tokenizer, model=model, batch_size=16, n_iters=8,
|
|
images = generate(query, tokenizer=tokenizer, model=model, batch_size=16, n_iters=8,
|
|
- temperature=args.temperature, filter_thres=args.filter_thres)
|
|
|
|
|
|
+ temperature=args.temperature, top_k=args.top_k, top_p=args.top_p)
|
|
|
|
|
|
images_for_clip = torch.cat([clip_preprocess(Image.fromarray((img * 255).astype(np.uint8))).unsqueeze(0).cuda() for img in images])
|
|
images_for_clip = torch.cat([clip_preprocess(Image.fromarray((img * 255).astype(np.uint8))).unsqueeze(0).cuda() for img in images])
|
|
text = clip.tokenize([query]).cuda()
|
|
text = clip.tokenize([query]).cuda()
|