|
@@ -127,7 +127,8 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
|
|
|
@_state.setter
|
|
|
def _state(self, new_state: State):
|
|
|
- self._shared_state_code[...] = ALL_STATES.index(new_state)
|
|
|
+ with torch.inference_mode():
|
|
|
+ self._shared_state_code[...] = ALL_STATES.index(new_state)
|
|
|
if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
|
|
|
self._set_event_threadsafe()
|
|
|
|