Forráskód Böngészése

make it compatible with low_cpu_mem_usage=True

dbaranchuk 3 éve
szülő
commit
66ce6d9669
3 módosított fájl, 12 hozzáadás és 8 törlés
  1. 7 3
      src/bloom/from_pretrained.py
  2. 3 4
      src/bloom/model.py
  3. 2 1
      src/client/remote_model.py

+ 7 - 3
src/bloom/from_pretrained.py

@@ -34,12 +34,13 @@ def load_pretrained_block(
     config: Optional[BloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
     use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     block = BloomBlock(config, layer_number=block_index)
-    state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
+    state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir)
     block.load_state_dict(state_dict)
 
     if torch_dtype == "auto":
@@ -57,7 +58,10 @@ def load_pretrained_block(
 
 
 def _load_state_dict(
-    pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
+    pretrained_model_name_or_path: str, 
+    block_index: Optional[int] = None, 
+    use_auth_token: Optional[str] = None, 
+    cache_dir: Optional[str] = None
 ) -> OrderedDict[str, torch.Tensor]:
     revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
     archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
@@ -65,7 +69,7 @@ def _load_state_dict(
     # Load from URL or cache if already cached
     resolved_archive_file = cached_path(
         archive_file,
-        cache_dir=None,
+        cache_dir=cache_dir,
         force_download=FORCE_DOWNLOAD,
         proxies=None,
         resume_download=RESUME_DOWNLOAD,

+ 3 - 4
src/bloom/model.py

@@ -156,9 +156,7 @@ class BloomModel(BloomPreTrainedModel):
         self.n_head = config.n_head
 
         # Embedding + LN Embedding
-
-        # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)  # dtype=config.torch_dtype
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)  
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Transformer blocks
@@ -229,7 +227,8 @@ class BloomModel(BloomPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        # Note: it supports only float32 or bfloat16 inputs
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
 
         output_shape = input_shape + (hidden_states.size(-1),)
 

+ 2 - 1
src/client/remote_model.py

@@ -90,7 +90,8 @@ class DistributedBloomModel(BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        # Note: it supports only float32 or bfloat16 inputs
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
         hidden_states = self.h(hidden_states)