justheuristic 3 жил өмнө
parent
commit
13a4cb3b5d

+ 29 - 19
hivemind/averaging/control.py

@@ -5,7 +5,10 @@ from typing import Optional
 import numpy as np
 import torch
 
-from hivemind.utils import MPFuture, DHTExpiration
+from hivemind.utils import MPFuture, DHTExpiration, get_logger
+
+
+logger = get_logger(__file__)
 
 
 class AveragingStage(Enum):
@@ -26,28 +29,31 @@ class StepControl(MPFuture):
 
 
     """
-    def __init__(self, scheduled_time: DHTExpiration, weight: float, wait_for_trigger: bool,
-                 gather_binary: bytes, timeout: Optional[float], allow_retries: bool):
+
+    def __init__(self, scheduled_time: DHTExpiration, deadline: Optional[float], allow_retries: bool,
+                 weight: float, gather_binary: bytes):
         super().__init__()
-        self._gather_binary, self._timeout, self._allow_retries = gather_binary, timeout, allow_retries
+        self._gather_binary, self._deadline, self._allow_retries = gather_binary, deadline, allow_retries
         self._trigger: Optional[MPFuture] = None
-        if not wait_for_trigger:
-            self.allow_allreduce()
         self._metadata = torch.zeros([18], dtype=torch.uint8).share_memory_()
         self.stage = AveragingStage.IDLE
         self.scheduled_time = scheduled_time
         self.weight = weight
-        self.can_modify = True
+        self.began_allreduce = False
 
-    def _attach_trigger(self, trigger: MPFuture):
-        assert self._trigger is None
+    def attach_trigger(self, trigger: MPFuture):
+        assert self._trigger is None, "trigger is already attached"
         self._trigger = trigger
 
     def allow_allreduce(self):
-        """Allows averager to begin allreduce when it finds a group."""
+        """Allow averager to begin allreduce when it finds a group. Meant to be triggered by user."""
+        assert self._trigger is not None, "StepControl does not have an attached trigger (not properly initialized)"
+        if self._trigger.done():
+            logger.warning("Trigger is already set")
         self._trigger.set_result(None)
 
     async def wait_for_trigger(self):
+        assert self._trigger is not None, "StepControl does not have an attached trigger (not properly initialized)"
         await self._trigger
 
     @property
@@ -56,8 +62,10 @@ class StepControl(MPFuture):
 
     @scheduled_time.setter
     def scheduled_time(self, scheduled_time):
-        assert self.can_modify, "cannot change scheduling after all-reduce has already started"
-        #TODO check that scheduled time is still within timeout
+        if self.began_allreduce:
+            logger.warning("Changing scheduled time has no effect after all-reduce has already started")
+        if scheduled_time >= self.deadline:
+            logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout.")
         struct.pack_into('d', self._metadata[0:8].numpy().data, 0, float(scheduled_time))
 
     @property
@@ -66,8 +74,9 @@ class StepControl(MPFuture):
 
     @weight.setter
     def weight(self, weight: float):
-        assert self.can_modify, "cannot change weights after all-reduce has already started"
         assert weight >= 0 and np.isfinite(weight)
+        if self.began_allreduce:
+            logger.warning("Changing weights has no effect after all-reduce has already started")
         struct.pack_into('d', self._metadata[8:16].numpy().data, 0, float(weight))
 
     @property
@@ -81,11 +90,11 @@ class StepControl(MPFuture):
         self._metadata[16] = stage.value
 
     @property
-    def can_modify(self) -> bool:
+    def began_allreduce(self) -> bool:
         return bool(self._metadata[17].item())
 
-    @can_modify.setter
-    def can_modify(self, value: bool):
+    @began_allreduce.setter
+    def began_allreduce(self, value: bool):
         self._metadata[17] = int(value)
 
     @property
@@ -93,13 +102,14 @@ class StepControl(MPFuture):
         return self._gather_binary
 
     @property
-    def timeout(self) -> DHTExpiration:
-        return self.timeout
+    def deadline(self) -> DHTExpiration:
+        return self._deadline
 
     @property
     def allow_retries(self) -> bool:
         return self._allow_retries
 
     def cancel(self) -> bool:
-        self._trigger.cancel()
+        if self._trigger is not None:
+            self._trigger.cancel()
         return self.cancel()