# utils.py
from typing import Dict, List, Tuple

from multiaddr import Multiaddr
from pydantic import BaseModel, StrictFloat, confloat, conint

from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
from hivemind.dht.validation import RecordValidatorBase
from hivemind.utils.logging import get_logger
  8. logger = get_logger(__name__)
  9. class LocalMetrics(BaseModel):
  10. step: conint(ge=0, strict=True)
  11. samples_per_second: confloat(ge=0.0, strict=True)
  12. samples_accumulated: conint(ge=0, strict=True)
  13. loss: StrictFloat
  14. mini_steps: conint(ge=0, strict=True)
  15. class MetricSchema(BaseModel):
  16. metrics: Dict[BytesWithPublicKey, LocalMetrics]
  17. def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
  18. signature_validator = RSASignatureValidator()
  19. validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix), signature_validator]
  20. return validators, signature_validator.local_public_key
  21. class TextStyle:
  22. BOLD = "\033[1m"
  23. BLUE = "\033[34m"
  24. RESET = "\033[0m"
  25. def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
  26. if only_p2p:
  27. unique_addrs = {addr["p2p"] for addr in visible_maddrs}
  28. initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
  29. else:
  30. initial_peers_str = " ".join(str(addr) for addr in visible_maddrs)
  31. logger.info(
  32. f"Running a DHT peer. To connect other peers to this one, use "
  33. f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
  34. )