Browse Source

Add filter_thres argument to inference script

Aleksandr Borzunov 3 years ago
parent
commit
c936ca21da
1 changed files with 4 additions and 2 deletions
  1. 4 2
      inference/run_inference.py

+ 4 - 2
inference/run_inference.py

@@ -77,7 +77,7 @@ def make_model():
 
 
 def generate(query, *, tokenizer, model,
-             batch_size=16, n_iters=1, temperature=0.5, filter_thres=0.5):
+             batch_size, n_iters, temperature, filter_thres):
     input_ids = torch.tensor(tokenizer(query, add_special_tokens=False, max_length=256, truncation=True)['input_ids'])
     input_ids = F.pad(input_ids, (0, 256 - len(input_ids)), value=1)
     input_ids = input_ids.repeat(batch_size, 1)
@@ -96,6 +96,7 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--queries', type=str, help='List of queries (*.txt, newline-separated)')
     parser.add_argument('--temperature', type=float, help='Sampling temperature')
+    parser.add_argument('--filter-thres', type=float, help='Sampling filtering threshold')
     parser.add_argument('--model', type=str, help='DALL-E checkpoint (*.pt)')
     parser.add_argument('--vqgan', type=str, help='VQGAN checkpoint (*.ckpt)')
     parser.add_argument('--vqgan-config', type=str, help='VQGAN config (*.yaml)')
@@ -127,7 +128,8 @@ def main():
     print(f'[*] Saving results to `{args.output_dir}`')
 
     for query in tqdm(queries):
-        images = generate(query, tokenizer=tokenizer, model=model, batch_size=16, n_iters=8, temperature=args.temperature)
+        images = generate(query, tokenizer=tokenizer, model=model, batch_size=16, n_iters=8,
+                          temperature=args.temperature, filter_thres=args.filter_thres)
 
         images_for_clip = torch.cat([clip_preprocess(Image.fromarray((img * 255).astype(np.uint8))).unsqueeze(0).cuda() for img in images])
         text = clip.tokenize([query]).cuda()