utils.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from typing import Dict, List, Tuple
  2. from multiaddr import Multiaddr
  3. from pydantic import BaseModel, StrictFloat, confloat, conint
  4. from hivemind import choose_ip_address
  5. from hivemind.dht.crypto import RSASignatureValidator
  6. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
  7. from hivemind.dht.validation import RecordValidatorBase
  8. from hivemind.utils.logging import TextStyle, get_logger
# Module-level logger from hivemind's logging utilities; log_visible_maddrs writes its messages through it.
logger = get_logger(__name__)
class LocalMetrics(BaseModel):
    """Per-peer metrics record, validated by pydantic when an instance is constructed.

    Every field uses a strict type (``strict=True`` / ``StrictFloat``), so pydantic does not coerce values
    from other types into the declared one; the exact strictness rules depend on the installed pydantic
    version. Presumably this is so that malformed records published by other peers are rejected rather
    than silently converted — TODO confirm intent.
    """

    # non-negative int; by its name, presumably the training step the metrics refer to — confirm with callers
    step: conint(ge=0, strict=True)
    # non-negative float throughput
    samples_per_second: confloat(ge=0.0, strict=True)
    # non-negative int count of samples
    samples_accumulated: conint(ge=0, strict=True)
    # strict float with no range constraint
    # NOTE(review): StrictFloat does not appear to exclude NaN/inf here — verify whether consumers guard against it
    loss: StrictFloat
    # non-negative int
    mini_steps: conint(ge=0, strict=True)
class MetricSchema(BaseModel):
    """Schema that ``make_validators`` hands to ``SchemaValidator`` (with ``prefix=run_id``) for DHT records."""

    # Maps hivemind ``BytesWithPublicKey`` keys to one ``LocalMetrics`` record each.
    # Presumably each key embeds the writer's public key, so the RSA signature validator that make_validators
    # also returns can tie a record to its author — TODO confirm against hivemind.dht.schema.
    metrics: Dict[BytesWithPublicKey, LocalMetrics]
  18. def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
  19. signature_validator = RSASignatureValidator()
  20. validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
  21. return validators, signature_validator.local_public_key
  22. def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
  23. if only_p2p:
  24. unique_addrs = {addr["p2p"] for addr in visible_maddrs}
  25. initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
  26. else:
  27. available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
  28. if available_ips:
  29. preferred_ip = choose_ip_address(available_ips)
  30. selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
  31. else:
  32. selected_maddrs = visible_maddrs
  33. initial_peers_str = " ".join(str(addr) for addr in selected_maddrs)
  34. logger.info(
  35. f"Running a DHT peer. To connect other peers to this one over the Internet, use "
  36. f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
  37. )
  38. logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")