|
@@ -1,15 +1,16 @@
|
|
|
import asyncio
|
|
|
-import torch
|
|
|
-import numpy as np
|
|
|
+from concurrent.futures import CancelledError
|
|
|
|
|
|
+import numpy as np
|
|
|
import pytest
|
|
|
-import hivemind
|
|
|
+import torch
|
|
|
+
|
|
|
from hivemind.proto.dht_pb2_grpc import DHTStub
|
|
|
-from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
|
|
|
-from hivemind.utils import MSGPackSerializer
|
|
|
-from concurrent.futures import CancelledError
|
|
|
from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
-from hivemind.utils import serialize_torch_tensor, deserialize_torch_tensor
|
|
|
+from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
|
|
|
+import hivemind
|
|
|
+from hivemind.utils import MSGPackSerializer, serialize_torch_tensor, deserialize_torch_tensor
|
|
|
+from hivemind.utils.mpfuture import FutureStateError
|
|
|
|
|
|
|
|
|
def test_mpfuture_result():
|
|
@@ -19,9 +20,9 @@ def test_mpfuture_result():
|
|
|
assert f1.result() == 321
|
|
|
|
|
|
for future in [f1, f2]:
|
|
|
- with pytest.raises(RuntimeError):
|
|
|
+ with pytest.raises(FutureStateError):
|
|
|
future.set_result(123)
|
|
|
- with pytest.raises(RuntimeError):
|
|
|
+ with pytest.raises(FutureStateError):
|
|
|
future.set_exception(ValueError())
|
|
|
assert future.cancel() is False
|
|
|
assert future.done() and not future.running() and not future.cancelled()
|
|
@@ -58,9 +59,9 @@ def test_mpfuture_cancel():
|
|
|
future.result()
|
|
|
with pytest.raises(CancelledError):
|
|
|
future.exception()
|
|
|
- with pytest.raises(RuntimeError):
|
|
|
+ with pytest.raises(FutureStateError):
|
|
|
future.set_result(123)
|
|
|
- with pytest.raises(RuntimeError):
|
|
|
+ with pytest.raises(FutureStateError):
|
|
|
future.set_exception(NotImplementedError())
|
|
|
assert future.cancelled() and future.done() and not future.running()
|
|
|
|
|
@@ -204,7 +205,6 @@ def test_serialize_tensor():
|
|
|
assert torch.allclose(hivemind.deserialize_torch_tensor(serialized_scalar), scalar)
|
|
|
|
|
|
|
|
|
-
|
|
|
def test_serialize_tuple():
|
|
|
test_pairs = (
|
|
|
((1, 2, 3), [1, 2, 3]),
|
|
@@ -264,5 +264,5 @@ def test_generic_data_classes():
|
|
|
assert heap_entry.key == "string_value" and heap_entry.expiration_time == DHTExpiration(10)
|
|
|
|
|
|
sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
|
|
|
- sorted_heap_entry = sorted([HeapEntry(expiration_time=DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
|
|
|
- assert all([heap_entry.expiration_time == value for heap_entry, value in zip(sorted_heap_entry, sorted_expirations)])
|
|
|
+ sorted_heap_entries = sorted([HeapEntry(DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
|
|
|
+ assert all([entry.expiration_time == value for entry, value in zip(sorted_heap_entries, sorted_expirations)])
|