expert_uid.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import random
  2. import re
  3. from typing import List, NamedTuple, Optional, Tuple, Union
  4. import hivemind
  5. from hivemind.dht import DHT
  6. from hivemind.utils import Endpoint, get_logger
  7. logger = get_logger(__name__)
  8. ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
  9. UidEndpoint = NamedTuple("UidEndpoint", [("uid", ExpertUID), ("endpoint", Endpoint)])
  10. UID_DELIMITER = "." # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
  11. FLAT_EXPERT = -1 # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
  12. UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$") # e.g. ffn_expert.98.76.54 - prefix + some dims
  13. PREFIX_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$") # e.g. expert. or ffn.45. (ends with ".")
  14. # formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
  15. def is_valid_uid(maybe_uid: str) -> bool:
  16. """An uid must contain a string expert type, followed by one or more .-separated numeric indices"""
  17. return bool(UID_PATTERN.fullmatch(maybe_uid))
  18. def is_valid_prefix(maybe_prefix: str) -> bool:
  19. """An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period"""
  20. return bool(PREFIX_PATTERN.fullmatch(maybe_prefix))
  21. def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]:
  22. """Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate"""
  23. uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER)
  24. pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1
  25. return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])
  26. def generate_uids_from_pattern(
  27. num_experts: int,
  28. expert_pattern: Optional[str],
  29. dht: Optional[DHT] = None,
  30. attempts_per_expert=10,
  31. ) -> List[str]:
  32. """
  33. Sample experts from a given pattern, optionally remove duplicates.
  34. :param num_experts: sample this many unique expert uids
  35. :param expert_pattern: a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256]\
  36. means "sample random experts between myprefix.0.0 and myprefix.255.255"
  37. :param dht: if specified, uses this DHT to check that expert uids are not yet occupied by other peers
  38. :param dht: whether to exclude expert uids that are already present in the DHT
  39. (you may disable it if you want to have the same expert on multiple peers)
  40. :param attempts_per_expert: give up if unable to generate a new expert uid after this many attempts per uid
  41. :note: this method is not strictly process-safe. If several servers run it concurrently, they have
  42. a small chance of sampling duplicate expert uids.
  43. """
  44. remaining_attempts = attempts_per_expert * num_experts
  45. found_uids, attempted_uids = list(), set()
  46. def _generate_uid():
  47. if expert_pattern is None:
  48. return f"expert{UID_DELIMITER}{attempts_per_expert * num_experts - remaining_attempts}"
  49. uid = []
  50. for block in expert_pattern.split(UID_DELIMITER):
  51. try:
  52. if "[" not in block and "]" not in block:
  53. uid.append(block)
  54. elif block.startswith("[") and block.endswith("]") and ":" in block:
  55. slice_start, slice_end = map(int, block[1:-1].split(":"))
  56. uid.append(str(random.randint(slice_start, slice_end - 1)))
  57. else:
  58. raise ValueError("Block must be either fixed or a range [from:to]")
  59. except KeyboardInterrupt:
  60. raise
  61. except Exception as e:
  62. raise ValueError(f"Expert pattern {expert_pattern} has invalid block {block}, {e}")
  63. return UID_DELIMITER.join(uid)
  64. while remaining_attempts > 0 and len(found_uids) < num_experts:
  65. # sample new expert uids at random
  66. new_uids = []
  67. while len(new_uids) + len(found_uids) < num_experts and remaining_attempts > 0:
  68. new_uid = _generate_uid()
  69. remaining_attempts -= 1
  70. if new_uid not in attempted_uids:
  71. attempted_uids.add(new_uid)
  72. new_uids.append(new_uid)
  73. if dht:
  74. existing_expert_uids = {
  75. found_expert.uid
  76. for found_expert in hivemind.moe.server.get_experts(dht, new_uids)
  77. if found_expert is not None
  78. }
  79. new_uids = [new_uid for new_uid in new_uids if new_uid not in existing_expert_uids]
  80. found_uids += new_uids
  81. if len(found_uids) != num_experts:
  82. logger.warning(
  83. f"Found only {len(found_uids)} out of {num_experts} free expert uids after "
  84. f"{attempts_per_expert * num_experts} attempts"
  85. )
  86. return found_uids