|
@@ -5,7 +5,10 @@ from typing import Optional
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
|
|
|
-from hivemind.utils import MPFuture, DHTExpiration
|
|
|
+from hivemind.utils import MPFuture, DHTExpiration, get_logger
|
|
|
+
|
|
|
+
|
|
|
+logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
class AveragingStage(Enum):
|
|
@@ -26,28 +29,31 @@ class StepControl(MPFuture):
|
|
|
|
|
|
|
|
|
"""
|
|
|
- def __init__(self, scheduled_time: DHTExpiration, weight: float, wait_for_trigger: bool,
|
|
|
- gather_binary: bytes, timeout: Optional[float], allow_retries: bool):
|
|
|
+
|
|
|
+ def __init__(self, scheduled_time: DHTExpiration, deadline: Optional[float], allow_retries: bool,
|
|
|
+ weight: float, gather_binary: bytes):
|
|
|
super().__init__()
|
|
|
- self._gather_binary, self._timeout, self._allow_retries = gather_binary, timeout, allow_retries
|
|
|
+ self._gather_binary, self._deadline, self._allow_retries = gather_binary, deadline, allow_retries
|
|
|
self._trigger: Optional[MPFuture] = None
|
|
|
- if not wait_for_trigger:
|
|
|
- self.allow_allreduce()
|
|
|
self._metadata = torch.zeros([18], dtype=torch.uint8).share_memory_()
|
|
|
self.stage = AveragingStage.IDLE
|
|
|
self.scheduled_time = scheduled_time
|
|
|
self.weight = weight
|
|
|
- self.can_modify = True
|
|
|
+ self.began_allreduce = False
|
|
|
|
|
|
- def _attach_trigger(self, trigger: MPFuture):
|
|
|
- assert self._trigger is None
|
|
|
+ def attach_trigger(self, trigger: MPFuture):
|
|
|
+ assert self._trigger is None, "trigger is already attached"
|
|
|
self._trigger = trigger
|
|
|
|
|
|
def allow_allreduce(self):
|
|
|
- """Allows averager to begin allreduce when it finds a group."""
|
|
|
+ """Allow averager to begin allreduce when it finds a group. Meant to be triggered by user."""
|
|
|
+ assert self._trigger is not None, "StepControl does not have an attached trigger (not properly initialized)"
|
|
|
+ if self._trigger.done():
|
|
|
+ logger.warning("Trigger is already set")
|
|
|
self._trigger.set_result(None)
|
|
|
|
|
|
async def wait_for_trigger(self):
|
|
|
+ assert self._trigger is not None, "StepControl does not have an attached trigger (not properly initialized)"
|
|
|
await self._trigger
|
|
|
|
|
|
@property
|
|
@@ -56,8 +62,10 @@ class StepControl(MPFuture):
|
|
|
|
|
|
@scheduled_time.setter
|
|
|
def scheduled_time(self, scheduled_time):
|
|
|
- assert self.can_modify, "cannot change scheduling after all-reduce has already started"
|
|
|
- #TODO check that scheduled time is still within timeout
|
|
|
+ if self.began_allreduce:
|
|
|
+ logger.warning("Changing scheduled time has no effect after all-reduce has already started")
|
|
|
+ if scheduled_time >= self.deadline:
|
|
|
+ logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout.")
|
|
|
struct.pack_into('d', self._metadata[0:8].numpy().data, 0, float(scheduled_time))
|
|
|
|
|
|
@property
|
|
@@ -66,8 +74,9 @@ class StepControl(MPFuture):
|
|
|
|
|
|
@weight.setter
|
|
|
def weight(self, weight: float):
|
|
|
- assert self.can_modify, "cannot change weights after all-reduce has already started"
|
|
|
assert weight >= 0 and np.isfinite(weight)
|
|
|
+ if self.began_allreduce:
|
|
|
+ logger.warning("Changing weights has no effect after all-reduce has already started")
|
|
|
struct.pack_into('d', self._metadata[8:16].numpy().data, 0, float(weight))
|
|
|
|
|
|
@property
|
|
@@ -81,11 +90,11 @@ class StepControl(MPFuture):
|
|
|
self._metadata[16] = stage.value
|
|
|
|
|
|
@property
|
|
|
- def can_modify(self) -> bool:
|
|
|
+ def began_allreduce(self) -> bool:
|
|
|
return bool(self._metadata[17].item())
|
|
|
|
|
|
- @can_modify.setter
|
|
|
- def can_modify(self, value: bool):
|
|
|
+ @began_allreduce.setter
|
|
|
+ def began_allreduce(self, value: bool):
|
|
|
self._metadata[17] = int(value)
|
|
|
|
|
|
@property
|
|
@@ -93,13 +102,14 @@ class StepControl(MPFuture):
|
|
|
return self._gather_binary
|
|
|
|
|
|
@property
|
|
|
- def timeout(self) -> DHTExpiration:
|
|
|
- return self.timeout
|
|
|
+ def deadline(self) -> DHTExpiration:
|
|
|
+ return self._deadline
|
|
|
|
|
|
@property
|
|
|
def allow_retries(self) -> bool:
|
|
|
return self._allow_retries
|
|
|
|
|
|
def cancel(self) -> bool:
|
|
|
- self._trigger.cancel()
|
|
|
+ if self._trigger is not None:
|
|
|
+ self._trigger.cancel()
|
|
|
return self.cancel()
|