Browse Source

Loading a bloom block working.

Tim Dettmers 3 years ago
parent
commit
43fdcac6aa
2 files changed with 16 additions and 2 deletions
  1. 3 0
      .gitignore
  2. 13 2
      cli/inference_one_block.py

+ 3 - 0
.gitignore

@@ -126,3 +126,6 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# vim
+*.swp

+ 13 - 2
cli/inference_one_block.py

@@ -32,18 +32,29 @@ if __name__ == "__main__":
     parser.add_argument("--layer_index", default=0, type=int, help="Optional path to saved block state dict")
     parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
     parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
+    parser.add_argument("--block-path", default='', type=str, help="The path to the Bloom block-path")
     args = parser.parse_args()
 
     if args.device is None:
         args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f'Using device {args.device}')
 
     config = DistributedBloomConfig.from_json_file(args.config)
-    block = BloomBlock(config, args.layer_index).to(args.device)
+    block = BloomBlock(config, args.layer_index)
+
+    if args.block_path != '':
+        print(f'Loading block from {args.block_path}')
+        block.load_state_dict( torch.load(args.block_path))
+        #print(list(block_data.keys()))
+        #block.load(args.block_path)
+
+    block = block.to(args.device)
+    block = block.to(torch.bfloat16)
 
     cache = None
 
     for i in trange(args.num_steps):
-        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
+        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device).to(torch.bfloat16)
         alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
         with torch.no_grad():
             outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)