|
@@ -52,7 +52,7 @@ def load_pretrained_block(
|
|
|
|
|
|
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
|
|
torch_dtype = resolve_block_dtype(config, torch_dtype)
|
|
|
-
|
|
|
+ print("block_index",block_index)
|
|
|
with init_empty_weights():
|
|
|
block = get_model_block(config, layer_idx=block_index)
|
|
|
|