Selaa lähdekoodia

add LM head for DistributedBloomForCausalLM

dbaranchuk 3 vuotta sitten
vanhempi
commit
b3cc9e0d99
1 muutettua tiedostoa jossa 4 lisäystä ja 4 poistoa
  1. 4 4
      src/client/remote_model.py

+ 4 - 4
src/client/remote_model.py

@@ -1,12 +1,11 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import hivemind
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import get_logger, use_hivemind_log_handler
 
-from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHeadForCausalLM
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
@@ -55,5 +54,6 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
         self.transformer = DistributedBloomModel(config)
+        self.lm_head = LMHeadForCausalLM(config)
         # Initialize weights and apply final processing
         self.post_init()