metrics_utils.py 915 B

12345678910111213141516171819202122232425
  1. from typing import Dict, List, Tuple
  2. from hivemind.dht.crypto import RSASignatureValidator
  3. from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
  4. from hivemind.dht.validation import RecordValidatorBase
  5. from pydantic import BaseModel, StrictFloat, confloat, conint
class LocalMetrics(BaseModel):
    """Per-peer metrics record, validated by pydantic.

    The integer fields use ``conint(ge=0, strict=True)`` and
    ``samples_per_second`` uses ``confloat(ge=0.0, strict=True)``. Values must
    therefore already have the declared type (strict mode, no coercion) and be
    non-negative. ``loss`` is a ``StrictFloat`` with no range constraint.

    NOTE(review): field order is part of the pydantic model definition
    (validation / serialization order); keep it stable.
    """

    # Presumably the reporting peer's current training/optimizer step — TODO confirm.
    step: conint(ge=0, strict=True)
    # Presumably throughput measured locally by the reporting peer — TODO confirm.
    samples_per_second: confloat(ge=0.0, strict=True)
    # Presumably samples accumulated toward the current step — TODO confirm.
    samples_accumulated: conint(ge=0, strict=True)
    # Strict float with no bounds. Whether NaN/inf are accepted depends on the
    # installed pydantic version's StrictFloat — verify if that matters.
    loss: StrictFloat
    # Presumably the number of local mini-batches processed — TODO confirm.
    mini_steps: conint(ge=0, strict=True)
class MetricSchema(BaseModel):
    """DHT record schema: a single ``metrics`` field of ``LocalMetrics`` values.

    ``make_validators`` in this module passes this model to hivemind's
    ``SchemaValidator``.
    """

    # Keys are hivemind ``BytesWithPublicKey``. Presumably each subkey must
    # embed the writing peer's public key, so that (together with
    # RSASignatureValidator) a peer can only update its own entry. Confirm
    # against the hivemind.dht.schema documentation.
    metrics: Dict[BytesWithPublicKey, LocalMetrics]
  14. def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]:
  15. signature_validator = RSASignatureValidator()
  16. validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix),
  17. signature_validator]
  18. return validators, signature_validator.local_public_key