task_prioritizer.py 731 B

1234567891011121314151617181920
  1. from abc import ABC, abstractmethod
  2. import torch
  3. from hivemind.moe.server.task_pool import Task
  4. class TaskPrioritizerBase(ABC):
  5. """Abstract class for TaskPrioritizer whose reponsibility is to evaluate task priority"""
  6. @abstractmethod
  7. def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
  8. """Evaluates task value by the amout of points given, task input and additional kwargs. Lower priority is better"""
  9. pass
  10. class DummyTaskPrioritizer(TaskPrioritizerBase):
  11. """Simple implementation of TaskPrioritizer which gives constant zero priority for every task"""
  12. def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
  13. return 0.0