Your Name 2 years ago
parent
commit
1879788705
1 changed files with 3 additions and 3 deletions
  1. 3 3
      src/petals/utils/packaging.py

+ 3 - 3
src/petals/utils/packaging.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any, Tuple, Sequence
 
 import torch
 from hivemind import nested_flatten, nested_pack
@@ -18,7 +18,7 @@ def _get_tensor_index(item: bytes) -> int:
     return int(item[3:])
 
 
-def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
+def pack_args_kwargs(*args, **kwargs) -> Tuple[Sequence[torch.Tensor], Any]:
     """
     Check the function's arguments and pack all tensors into different flattened lists.
     :returns: a flattened list of tensors and args and kwargs, where tensors were masked
@@ -35,7 +35,7 @@ def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
     return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
 
 
-def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
+def unpack_args_kwargs(flat_tensors: Sequence[torch.Tensor], args_structure: Any):
     """
     Restore arguments after `pack_args_kwargs` function.
     :returns: list of args and dict of kwargs