@@ -47,30 +47,10 @@ def load_pretrained_block(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR

-    with init_empty_weights():
-        block = WrappedBloomBlock(config)
-
-    state_dict = _load_state_dict(
-        converted_model_name_or_path,
-        block_index,
-        config,
-        use_auth_token=use_auth_token,
-        cache_dir=cache_dir,
-        max_disk_space=max_disk_space,
-    )
-
-    # dummy load, check that keys match
-    report = block.load_state_dict(state_dict, strict=True)
-    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
-
-    for param_name, _ in block.named_parameters():
-        assert param_name in state_dict, f"{param_name} not in state dict"
-        param = state_dict[param_name]
-        if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-            param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
-
-    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
+    block = WrappedBloomBlock(config).to(torch.bfloat16)
+
+    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}")
+    logger.warning(f"Debug mode: loaded empty block of type {set(param.dtype for param in block.parameters())}")
     return block

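For reference, the dtype rule in the removed loop casts only floating-point tensors; integer and bool buffers keep their original dtype. Below is a minimal, self-contained sketch of that rule in plain PyTorch; the `cast_param` helper and the toy state dict are hypothetical stand-ins for the checkpoint shards returned by `_load_state_dict`:

```python
import torch

def cast_param(param: torch.Tensor, torch_dtype) -> torch.Tensor:
    # Mirror of the removed check: leave integer/bool tensors alone,
    # cast everything else to the requested dtype unless it is "auto".
    if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
        param = param.to(torch_dtype)
    return param

# Toy state dict: one float weight, one bool buffer (hypothetical names).
state_dict = {"w": torch.randn(2, 2), "mask": torch.ones(2, dtype=torch.bool)}
casted = {name: cast_param(p, torch.bfloat16) for name, p in state_dict.items()}
assert casted["w"].dtype == torch.bfloat16  # float weight is cast
assert casted["mask"].dtype == torch.bool   # bool buffer is untouched
```

Note that the replacement path skips checkpoint loading entirely: `WrappedBloomBlock(config).to(torch.bfloat16)` leaves the block with randomly initialized weights, which is what the added "Debug mode: loaded empty block" warning flags.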