|
@@ -1,8 +1,6 @@
|
|
|
import multiprocessing as mp
|
|
|
import multiprocessing.pool
|
|
|
-from concurrent.futures import as_completed, TimeoutError, Future
|
|
|
from functools import partial
|
|
|
-from itertools import chain
|
|
|
from typing import Tuple, List, Dict, Any, Optional
|
|
|
|
|
|
import numpy as np
|
|
@@ -11,8 +9,7 @@ import torch.nn as nn
|
|
|
from torch.autograd.function import once_differentiable
|
|
|
|
|
|
from .expert import RemoteExpert, _RemoteModuleCall
|
|
|
-from ..utils import nested_map, check_numpy, run_in_background, run_and_await_k, nested_pack, BatchTensorProto, \
|
|
|
- nested_flatten, DUMMY
|
|
|
+from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY
|
|
|
from ..utils.autograd import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward
|
|
|
|
|
|
|