control.py

import os
import struct
from enum import Enum
from typing import Optional

import numpy as np
import torch

from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger

logger = get_logger(__name__)


class AveragingStage(Enum):
    IDLE = 0  # still initializing
    LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
    AWAITING_TRIGGER = 2  # waiting for user to set the trigger that allows running allreduce
    RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
    FINISHED = 4  # either done or failed with exception


class StepControl(MPFuture):
    """
    An auxiliary data structure that allows the user to control stages and track progress in a single averaging step.

    :param scheduled_time: estimated time when averaging should begin; will be used for scheduling
    :param deadline: if averaging is still in progress at this time, it should be stopped due to TimeoutError
    :param allow_retries: if True, allow running matchmaking and all-reduce again if the previous attempt fails
    :param weight: averaging weight, can be changed afterwards
    :param data_for_gather: send this data to all peers in the next group and gather it from groupmates
    """

    # indices for the shared buffer
    _SCHEDULED_TIME, _WEIGHT, _STAGE, _BEGAN_ALLREDUCE = slice(0, 8), slice(8, 16), 16, 17

    def __init__(
        self,
        scheduled_time: DHTExpiration,
        deadline: float,
        allow_retries: bool,
        weight: float,
        data_for_gather: bytes,
    ):
        super().__init__()
        self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
        self._trigger: Optional[MPFuture] = None
        self._cancel: Optional[MPFuture] = None

        # Buffer contents:
        # scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce (bool, 1 byte)
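        # The buffer lives in torch shared memory, so property reads/writes below stay consistent
        # across processes once this StepControl is pickled to the averager (see __getstate__)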
        self._shared_buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()
        self.stage = AveragingStage.IDLE
        self.scheduled_time = scheduled_time
        self.weight = weight
        self.began_allreduce = False

    def attach(self, trigger: MPFuture, cancel: MPFuture):
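        """Attach the trigger and cancel futures created by the averager; may be called at most once."""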
        assert self._trigger is None and self._cancel is None, "Futures are already attached"
        self._trigger, self._cancel = trigger, cancel

    def allow_allreduce(self):
        """Allow averager to begin all-reduce when it finds a group. Meant to be triggered by user."""
        assert self._trigger is not None, "StepControl does not have an attached trigger"
        if self._trigger.done():
            logger.warning("Trigger is already set")
        else:
            self._trigger.set_result(None)

    async def wait_for_trigger(self):
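        """Wait until the user permits all-reduce via allow_allreduce; meant to be awaited from inside the averager."""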
        assert self._trigger is not None, "StepControl does not have an attached trigger"
        await self._trigger

    @property
    def triggered(self) -> bool:
        assert self._trigger is not None, "StepControl does not have an attached trigger"
        return self._trigger.done()

    @property
    def scheduled_time(self) -> DHTExpiration:
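        # zero-copy read: reinterpret the first 8 bytes of the shared buffer as a packed double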
        return struct.unpack("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data)[0]

    @scheduled_time.setter
    def scheduled_time(self, scheduled_time):
        if self.began_allreduce:
            logger.warning("Changing scheduled time has no effect after all-reduce has already started")
        if scheduled_time >= self.deadline:
            logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout")
        struct.pack_into("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data, 0, float(scheduled_time))

    @property
    def weight(self) -> float:
        return struct.unpack("d", self._shared_buffer[StepControl._WEIGHT].numpy().data)[0]

    @weight.setter
    def weight(self, weight: float):
        assert weight >= 0 and np.isfinite(weight)
        if self.began_allreduce:
            logger.warning("Changing weight has no effect after all-reduce has already started")
        struct.pack_into("d", self._shared_buffer[StepControl._WEIGHT].numpy().data, 0, float(weight))

    @property
    def stage(self) -> AveragingStage:
        return AveragingStage(self._shared_buffer[StepControl._STAGE].item())

    @stage.setter
    def stage(self, stage: AveragingStage):
        if stage == AveragingStage.RUNNING_ALLREDUCE:
            self.began_allreduce = True
        self._shared_buffer[StepControl._STAGE] = stage.value

    @property
    def began_allreduce(self) -> bool:
        return bool(self._shared_buffer[StepControl._BEGAN_ALLREDUCE].item())

    @began_allreduce.setter
    def began_allreduce(self, value: bool):
        self._shared_buffer[StepControl._BEGAN_ALLREDUCE] = int(value)

    @property
    def data_for_gather(self) -> bytes:
        return self._data_for_gather

    @property
    def deadline(self) -> DHTExpiration:
        return self._deadline

    def get_timeout(self) -> Optional[DHTExpiration]:
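        """Return the time remaining until the deadline, clamped to be non-negative."""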
        return max(0.0, self.deadline - get_dht_time())

    @property
    def allow_retries(self) -> bool:
        return self._allow_retries

    def __getstate__(self):
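        # pickle the paired futures, the shared-memory buffer, and the immutable constructor
        # arguments so that the receiving process can fully reconstruct this StepControl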
        return dict(
            super().__getstate__(),
            _trigger=self._trigger,
            _cancel=self._cancel,
            _shared_buffer=self._shared_buffer,
            immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
        )

    def __setstate__(self, state):
        super().__setstate__(state)
        self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
        self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]

    def __del__(self):
        if os.getpid() == self._origin_pid and not self.triggered:
            logger.warning(
                "Deleted an averaging StepControl, but the step was not triggered. This may cause other "
                "peers to fail an averaging round via TimeoutError."
            )

        super().__del__()

    def cancel(self) -> bool:
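        """Cancel the step: revoke the trigger, signal the averager via the cancel future, then cancel this future."""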
        if self._trigger is not None:
            self._trigger.cancel()
        if self._cancel is not None:
            self._cancel.set_result(None)
        return super().cancel()

    async def wait_for_cancel(self):
        """Wait until the step is cancelled by the user. Should be called from inside the averager."""
        await self._cancel
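

# Typical lifecycle (a minimal sketch, not part of this module): the averager creates a StepControl,
# attaches the trigger/cancel futures, and hands it to the user, who releases the trigger when ready.
# This assumes a DecentralizedAverager whose step() returns a StepControl when called with wait=False
# and require_trigger=True; check your hivemind version for the exact signature.
#
#     control = averager.step(wait=False, require_trigger=True)
#     ...                                # overlap other work with background matchmaking
#     control.weight = 0.5               # may be adjusted until all-reduce begins
#     control.allow_allreduce()          # release the trigger so the group can start all-reduce
#     gathered = control.result()        # StepControl is an MPFuture: block until the step finishes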