Selaa lähdekoodia

Fix typing and sort imports

Aleksandr Borzunov 4 vuotta sitten
vanhempi
commit
0cba0a07ed
2 muutettua tiedostoa jossa 6 lisäystä ja 6 poistoa
  1. 2 2
      hivemind/averaging/allreduce.py
  2. 4 4
      hivemind/averaging/matchmaking.py

+ 2 - 2
hivemind/averaging/allreduce.py

@@ -1,6 +1,6 @@
 import asyncio
-from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
 from enum import Enum
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type, Union
 
 import torch
 
@@ -51,7 +51,7 @@ class AllReduceRunner(ServicerBase):
         self,
         *,
         p2p: P2P,
-        servicer: Optional[ServicerBase],
+        servicer: Optional[Union[ServicerBase, Type[ServicerBase]]],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
         ordered_group_endpoints: Sequence[Endpoint],

+ 4 - 4
hivemind/averaging/matchmaking.py

@@ -2,12 +2,12 @@
 
 from __future__ import annotations
 
+import asyncio
+import concurrent.futures
 import contextlib
 import random
 from math import isfinite
-from typing import Optional, AsyncIterator, Set, Tuple, Dict
-import concurrent.futures
-import asyncio
+from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type, Union
 
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
@@ -36,7 +36,7 @@ class Matchmaking:
     def __init__(
         self,
         p2p: P2P,
-        servicer: ServicerBase,
+        servicer: Union[ServicerBase, Type[ServicerBase]],
         schema_hash: bytes,
         dht: DHT,
         *,