utils.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from typing import Dict, List, Tuple
  2. import transformers.utils.logging
  3. from multiaddr import Multiaddr
  4. from pydantic import BaseModel, StrictFloat, confloat, conint
  5. from hivemind import choose_ip_address
  6. from hivemind.dht.crypto import RSASignatureValidator
  7. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
  8. from hivemind.dht.validation import RecordValidatorBase
  9. from hivemind.utils.logging import get_logger
  10. from transformers.trainer_utils import is_main_process
# Module-level logger, created through hivemind's ``get_logger``; used by
# ``log_visible_maddrs`` and ``setup_logging`` below.
logger = get_logger(__name__)
  12. class LocalMetrics(BaseModel):
  13. step: conint(ge=0, strict=True)
  14. samples_per_second: confloat(ge=0.0, strict=True)
  15. samples_accumulated: conint(ge=0, strict=True)
  16. loss: StrictFloat
  17. mini_steps: conint(ge=0, strict=True)
class MetricSchema(BaseModel):
    """Pydantic schema for the metrics record kept in the DHT.

    ``make_validators`` passes this model to hivemind's ``SchemaValidator`` together with
    the experiment prefix; presumably records under that prefix are validated against it.
    """

    # Maps each key to one ``LocalMetrics`` entry. ``BytesWithPublicKey`` comes from
    # hivemind.dht.schema; presumably it requires the key to embed the writer's public key
    # so that entries can be tied to their signer — confirm against hivemind's docs.
    metrics: Dict[BytesWithPublicKey, LocalMetrics]
  20. def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
  21. signature_validator = RSASignatureValidator()
  22. validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix), signature_validator]
  23. return validators, signature_validator.local_public_key
  24. class TextStyle:
  25. BOLD = "\033[1m"
  26. BLUE = "\033[34m"
  27. RESET = "\033[0m"
  28. def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
  29. if only_p2p:
  30. unique_addrs = {addr["p2p"] for addr in visible_maddrs}
  31. initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
  32. else:
  33. available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
  34. if available_ips:
  35. preferred_ip = choose_ip_address(available_ips)
  36. selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
  37. else:
  38. selected_maddrs = visible_maddrs
  39. initial_peers_str = " ".join(str(addr) for addr in selected_maddrs)
  40. logger.info(
  41. f"Running a DHT peer. To connect other peers to this one over the Internet, use "
  42. f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
  43. )
  44. logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")
  45. def setup_logging(training_args):
  46. # Log on each process the small summary:
  47. logger.warning(
  48. f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
  49. + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
  50. )
  51. # Set the verbosity to info of the Transformers logger (on main process only):
  52. if is_main_process(training_args.local_rank):
  53. transformers.utils.logging.set_verbosity_info()
  54. transformers.utils.logging.enable_default_handler()
  55. transformers.utils.logging.enable_explicit_format()
  56. logger.info("Training/evaluation parameters %s", training_args)