config.py

import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.falcon import FalconConfig
from transformers.models.falcon.modeling_falcon import FalconAttention

from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.falcon.block import WrappedFalconBlock
from petals.utils.auto_config import DefaultRevisionMixin

logger = get_logger(__name__)


class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    block_class = WrappedFalconBlock
    attn_class = FalconAttention
    block_prefix = "transformer.h"

    @property
    def num_key_value_groups(self) -> int:
        if self.new_decoder_architecture:
            return self.num_attention_heads // self.num_kv_heads
        if self.multi_query:
            return self.num_attention_heads
        return 1
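
    # Worked example (illustrative, not from the original file): for a
    # new-decoder-architecture checkpoint with num_attention_heads=128 and
    # num_kv_heads=8 (the Falcon-40B layout), num_key_value_groups is
    # 128 // 8 == 16, i.e. 16 query heads share each key/value head. Under
    # multi_query (e.g. Falcon-7B), every query head shares one KV pair.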

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
            dht_prefix = dht_prefix.replace(".", "-")
            logger.info(f"Using DHT prefix: {dht_prefix}")
        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        if config.pad_token_id is None:
            config.pad_token_id = 0
        return result
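

# Minimal usage sketch (illustrative, not part of the original file; the repo
# name below is hypothetical):
#
#   config = DistributedFalconConfig.from_pretrained("someorg/falcon-rw-1b")
#
# Because the argument names a Hub repo rather than a local directory and no
# dht_prefix was passed, the prefix is derived from the repo name alone:
# "someorg/falcon-rw-1b" -> "falcon-rw-1b", with any dots replaced by dashes
# so that blocks of the same model hosted under different accounts share one
# DHT namespace. pad_token_id also falls back to 0 when the checkpoint leaves
# it unset.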