Переглянути джерело

Debug mode: load empty block

Aleksandr Borzunov 2 роки тому
батько
коміт
2cfd70d751
1 змінений файл: 4 додано та 24 видалено
  1. 4 24
      src/petals/bloom/from_pretrained.py

+ 4 - 24
src/petals/bloom/from_pretrained.py

@@ -47,30 +47,10 @@ def load_pretrained_block(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
 
-    with init_empty_weights():
-        block = WrappedBloomBlock(config)
-
-    state_dict = _load_state_dict(
-        converted_model_name_or_path,
-        block_index,
-        config,
-        use_auth_token=use_auth_token,
-        cache_dir=cache_dir,
-        max_disk_space=max_disk_space,
-    )
-
-    # dummy load, check that keys match
-    report = block.load_state_dict(state_dict, strict=True)
-    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
-
-    for param_name, _ in block.named_parameters():
-        assert param_name in state_dict, f"{param_name} not in state dict"
-        param = state_dict[param_name]
-        if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-            param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
-
-    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
+    block = WrappedBloomBlock(config).to(torch.bfloat16)
+
+    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}")
+    logger.warning(f"Debug mode: loaded empty block of type {set(param.dtype for param in block.parameters())}")
     return block