control.py

import struct
from enum import Enum
from typing import Optional

import numpy as np
import torch

from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger

logger = get_logger(__name__)


class AveragingStage(Enum):
    IDLE = 0  # still initializing
    LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
    AWAITING_TRIGGER = 2  # waiting for user to set the trigger that allows running allreduce
    RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
    FINISHED = 4  # either done or failed with exception


class StepControl(MPFuture):
    """
    An auxiliary data structure that allows the user to control stages and track progress in a single averaging step

    :param scheduled_time: estimated time when averaging should begin; will be used for scheduling
    :param deadline: if averaging is still in progress at this time, it should be stopped due to TimeoutError
    :param allow_retries: if True, allow running matchmaking and all-reduce again if the previous attempt fails
    :param weight: averaging weight, can be changed afterwards
    :param data_for_gather: send this data to all peers in the next group and gather it from groupmates
    """

    # indices for the shared buffer
    _SCHEDULED_TIME, _WEIGHT, _STAGE, _BEGAN_ALLREDUCE = slice(0, 8), slice(8, 16), 16, 17

    def __init__(
        self,
        scheduled_time: DHTExpiration,
        deadline: float,
        allow_retries: bool,
        weight: float,
        data_for_gather: bytes,
    ):
        super().__init__()
        self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
        self._trigger: Optional[MPFuture] = None
        self._cancel: Optional[MPFuture] = None

        # Buffer contents:
        # scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce (bool, 1 byte)
        self._shared_buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()
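        # share_memory_() places this tensor in shared memory, so the properties
        # below, which pack and unpack values through the tensor's numpy view,
        # stay consistent between the user process and the averager process.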
        self.stage = AveragingStage.IDLE
        self.scheduled_time = scheduled_time
        self.weight = weight
        self.began_allreduce = False
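
    # The trigger/cancel futures attached below coordinate the two sides of a
    # step: `trigger` is set by the user (via allow_allreduce) to permit
    # all-reduce, and `cancel` is set by cancel() so that the averager-side
    # process can observe cancellation (see wait_for_cancel).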
    def attach(self, trigger: MPFuture, cancel: MPFuture):
        assert self._trigger is None and self._cancel is None, "Futures are already attached"
        self._trigger, self._cancel = trigger, cancel

    def allow_allreduce(self):
        """Allow the averager to begin all-reduce when it finds a group. Meant to be triggered by the user."""
        assert self._trigger is not None, "StepControl does not have an attached trigger"
        if self._trigger.done():
            logger.warning("Trigger is already set")
        else:
            self._trigger.set_result(None)

    async def wait_for_trigger(self):
        assert self._trigger is not None, "StepControl does not have an attached trigger"
        await self._trigger

    @property
    def triggered(self) -> bool:
        assert self._trigger is not None, "StepControl does not have an attached trigger"
        return self._trigger.done()
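
    # scheduled_time and weight are stored as doubles packed into the two
    # 8-byte slices of the shared buffer; struct.pack_into / struct.unpack
    # operate on the buffer's numpy view so that writes land in shared memory.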
    @property
    def scheduled_time(self) -> DHTExpiration:
        return struct.unpack("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data)[0]

    @scheduled_time.setter
    def scheduled_time(self, scheduled_time):
        if self.began_allreduce:
            logger.warning("Changing scheduled time has no effect after all-reduce has already started")
        if scheduled_time >= self.deadline:
            logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout")
        struct.pack_into("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data, 0, float(scheduled_time))

    @property
    def weight(self) -> float:
        return struct.unpack("d", self._shared_buffer[StepControl._WEIGHT].numpy().data)[0]

    @weight.setter
    def weight(self, weight: float):
        assert weight >= 0 and np.isfinite(weight)
        if self.began_allreduce:
            logger.warning("Changing weights has no effect after all-reduce has already started")
        struct.pack_into("d", self._shared_buffer[StepControl._WEIGHT].numpy().data, 0, float(weight))

    @property
    def stage(self) -> AveragingStage:
        return AveragingStage(self._shared_buffer[StepControl._STAGE].item())

    @stage.setter
    def stage(self, stage: AveragingStage):
        if stage == AveragingStage.RUNNING_ALLREDUCE:
            self.began_allreduce = True
        self._shared_buffer[StepControl._STAGE] = stage.value

    @property
    def began_allreduce(self) -> bool:
        return bool(self._shared_buffer[StepControl._BEGAN_ALLREDUCE].item())

    @began_allreduce.setter
    def began_allreduce(self, value: bool):
        self._shared_buffer[StepControl._BEGAN_ALLREDUCE] = int(value)

    @property
    def data_for_gather(self) -> bytes:
        return self._data_for_gather

    @property
    def deadline(self) -> DHTExpiration:
        return self._deadline

    def get_timeout(self) -> Optional[DHTExpiration]:
        return max(0.0, self.deadline - get_dht_time())

    @property
    def allow_retries(self) -> bool:
        return self._allow_retries
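
    # Pickling support: in addition to the MPFuture state, the attached futures
    # and the shared buffer travel with the object, so a StepControl sent to
    # another process observes (and can modify) the same underlying step.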
    def __getstate__(self):
        return dict(
            super().__getstate__(),
            _trigger=self._trigger,
            _cancel=self._cancel,
            _shared_buffer=self._shared_buffer,
            immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
        )

    def __setstate__(self, state):
        super().__setstate__(state)
        self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
        self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]

    def cancel(self) -> bool:
        if self._trigger is not None:
            self._trigger.cancel()
        if self._cancel is not None:
            self._cancel.set_result(None)
        return super().cancel()

    async def wait_for_cancel(self):
        """Wait for the step to be cancelled by the user. Should be called from inside the averager."""
        await self._cancel
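

# Usage sketch (illustrative, not part of the module): a user-side view of one
# averaging step, assuming an averager has created this StepControl and attached
# its trigger and cancel futures (in hivemind, such a control object is
# typically obtained from DecentralizedAverager.step(wait=False)):
#
#     control.weight = 0.5                 # adjust parameters while matchmaking runs
#     if control.stage == AveragingStage.AWAITING_TRIGGER:
#         control.allow_allreduce()        # permit all-reduce once a group is found
#     result = control.result(timeout=control.get_timeout())  # wait for completion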