expert_uid.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import random
  2. import re
  3. from typing import List, NamedTuple, Optional, Tuple, Union
  4. import hivemind
  5. from hivemind.dht import DHT
  6. from hivemind.utils import Endpoint, get_logger
logger = get_logger(__name__)

# Lightweight type aliases: uids and prefixes are plain strings, grid coordinates are ints, scores are floats.
ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
# Pairs an expert uid with the endpoint recorded for it.
UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
UID_DELIMITER = "."  # when declaring experts, the DHT stores all prefixes of that expert's uid, split over this delimiter
FLAT_EXPERT = -1  # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in the 1d case.
# A uid is a dot-free expert type followed by one or more ".<index>" parts; an index is either 0 or a number
# without leading zeros.
UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$")  # e.g. ffn_expert.98.76.54 - prefix + some dims
# A prefix has the same shape as a uid, but may have zero indices and must end with a trailing delimiter.
PREFIX_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$")  # e.g. expert. or ffn.45. (ends with ".")
# formally (pseudo-code), prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
  15. def is_valid_uid(maybe_uid: str) -> bool:
  16. """An uid must contain a string expert type, followed by one or more .-separated numeric indices"""
  17. return bool(UID_PATTERN.fullmatch(maybe_uid))
  18. def is_valid_prefix(maybe_prefix: str) -> bool:
  19. """An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period"""
  20. return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
  21. def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
  22. """Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate"""
  23. uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
  24. pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
  25. return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
  26. def generate_uids_from_pattern(
  27. num_experts: int, expert_pattern: Optional[str], dht: Optional[DHT] = None, attempts_per_expert=10
  28. ) -> List[str]:
  29. """
  30. Sample experts from a given pattern, remove duplicates.
  31. :param num_experts: sample this many unique expert uids
  32. :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
  33. means "sample random experts between myprefix.0.0 and myprefix.255.255;
  34. :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
  35. :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
  36. :note: this method is not strictly process-safe. If several servers run it concurrently, they have
  37. a small chance of sampling duplicate expert uids.
  38. """
  39. remaining_attempts = attempts_per_expert * num_experts
  40. found_uids, attempted_uids = list(), set()
  41. def _generate_uid():
  42. if expert_pattern is None:
  43. return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
  44. uid = []
  45. for block in expert_pattern.split(UID_DELIMITER):
  46. try:
  47. if "[" not in block and "]" not in block:
  48. uid.append(block)
  49. elif block.startswith("[") and block.endswith("]") and ":" in block:
  50. slice_start, slice_end = map(int, block[1:-1].split(":"))
  51. uid.append(str(random.randint(slice_start, slice_end - 1)))
  52. else:
  53. raise ValueError("Block must be either fixed or a range [from:to]")
  54. except KeyboardInterrupt:
  55. raise
  56. except Exception as e:
  57. raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
  58. return UID_DELIMITER.join(uid)
  59. while remaining_attempts > 0 and len(found_uids) < num_experts:
  60. # 1. sample new expert uids at random
  61. new_uids = []
  62. while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
  63. new_uid = _generate_uid()
  64. remaining_attempts -= 1
  65. if new_uid not in attempted_uids:
  66. attempted_uids.add(new_uid)
  67. new_uids.append(new_uid)
  68. # 2. look into DHT (if given) and remove duplicates
  69. if dht:
  70. existing_expert_uids = {
  71. found_expert.uid
  72. for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
  73. if found_expert is not None
  74. }
  75. new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
  76. found_uids += new_uids
  77. if len(found_uids) != num_experts:
  78. logger.warning(
  79. f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
  80. f"{attempts_per_expert * num_experts} attempts"
  81. )
  82. return found_uids