# utils.py
  1. from typing import Dict, List, Tuple
  2. from pydantic import BaseModel, StrictFloat, confloat, conint
  3. from hivemind.dht.crypto import RSASignatureValidator
  4. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
  5. from hivemind.dht.validation import RecordValidatorBase
  6. from hivemind.utils.logging import get_logger
  7. logger = get_logger(__name__)
  8. class LocalMetrics(BaseModel):
  9. step: conint(ge=0, strict=True)
  10. samples_per_second: confloat(ge=0.0, strict=True)
  11. samples_accumulated: conint(ge=0, strict=True)
  12. loss: StrictFloat
  13. mini_steps: conint(ge=0, strict=True)
  14. class MetricSchema(BaseModel):
  15. metrics: Dict[BytesWithPublicKey, LocalMetrics]
  16. def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
  17. signature_validator = RSASignatureValidator()
  18. validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
  19. return validators, signature_validator.local_public_key