|
@@ -68,7 +68,7 @@ def load_pretrained_block(
|
|
|
param = state_dict[param_name]
|
|
|
if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
|
|
param = param.to(torch_dtype)
|
|
|
- set_module_tensor_to_device(block, param_name, "cpu", value=param)
|
|
|
+ set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
|
|
|
|
|
|
logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
|
|
|
return block
|