Explorar o código

Fix MPFuture failing outside inference mode (#521)

Alexander Borzunov %!s(int64=2) %!d(string=hai) anos
pai
achega
8f258b4b36
Modificáronse 1 ficheiros con 2 adicións e 1 borrados
  1. 2 1
      hivemind/utils/mpfuture.py

+ 2 - 1
hivemind/utils/mpfuture.py

@@ -127,7 +127,8 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     @_state.setter
     def _state(self, new_state: State):
-        self._shared_state_code[...] = ALL_STATES.index(new_state)
+        with torch.inference_mode():
+            self._shared_state_code[...] = ALL_STATES.index(new_state)
         if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
             self._set_event_threadsafe()