@@ -1,12 +1,11 @@
 # this code is in active development, interfaces may change
 import os
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import hivemind
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import get_logger, use_hivemind_log_handler
 
-from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
+from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHeadForCausalLM
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
@@ -55,5 +54,6 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
     def __init__(self, config: DistributedBloomConfig):
         BloomPreTrainedModel.__init__(self, config)
         self.transformer = DistributedBloomModel(config)
+        self.lm_head = LMHeadForCausalLM(config)
         # Initialize weights and apply final processing
         self.post_init()
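
For context, a minimal self-contained sketch (not part of the patch) of the composition pattern the second hunk introduces: a backbone plus an explicitly constructed LM head. The Toy* classes below are hypothetical stand-ins for DistributedBloomModel and LMHeadForCausalLM, not the project's actual API.

import torch
from torch import nn

class ToyBackbone(nn.Module):
    # hypothetical stand-in for DistributedBloomModel(config)
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.word_embeddings(input_ids)

class ToyCausalLM(nn.Module):
    # hypothetical stand-in for DistributedBloomForCausalLM after this change:
    # the LM head is built as its own module instead of being left implicit
    def __init__(self, vocab_size: int = 128, hidden_size: int = 32):
        super().__init__()
        self.transformer = ToyBackbone(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.transformer(input_ids))

logits = ToyCausalLM()(torch.randint(0, 128, (1, 8)))
print(logits.shape)  # torch.Size([1, 8, 128])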