Przeglądaj źródła

fix problem with NamedTuple inheritance in python3.9 (#142)

* fix python3.9 problem with NamedTuple inheritance

* NamedTuple -> dataclass

* add python3.9.1 to CircleCI

* add python3.9.1 to CircleCI

* add python3.9.1 to CircleCI

* bump version of hivemind
Anton Sinitsin 4 lat temu
rodzic
commit
0d7818b2cd

+ 26 - 2
.circleci/config.yml

@@ -1,7 +1,7 @@
 version: 2.1
 
 jobs:
-  build-and-test:
+  build-and-test-py3-8-1:
     docker:
       - image: circleci/python:3.8.1
     steps:
@@ -24,8 +24,32 @@ jobs:
       - run:
           command: codecov
           name: codecov
+  build-and-test-py3-9-1:
+    docker:
+      - image: circleci/python:3.9.1
+    steps:
+      - checkout
+      - restore_cache:
+          keys:
+            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+      - run: pip install -r requirements.txt
+      - run: pip install -r requirements-dev.txt
+      - save_cache:
+          key: v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+          paths:
+            - '~/.cache/pip'
+      - run:
+          command: pip install -e .
+          name: setup
+      - run:
+          command: pytest ./tests
+          name: tests
+      - run:
+          command: codecov
+          name: codecov
 
 workflows:
   main:
     jobs:
-      - build-and-test
+      - build-and-test-py3-8-1
+      - build-and-test-py3-9-1

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.27'
+__version__ = '0.8.28'

+ 24 - 5
hivemind/utils/timed_storage.py

@@ -3,7 +3,8 @@ from __future__ import annotations
 import heapq
 import time
 from contextlib import contextmanager
-from typing import TypeVar, NamedTuple, Generic, Optional, Dict, List, Iterator, Tuple
+from typing import TypeVar, Generic, Optional, Dict, List, Iterator, Tuple
+from dataclasses import dataclass
 
 KeyType = TypeVar('KeyType')
 ValueType = TypeVar('ValueType')
@@ -12,17 +13,35 @@ MAX_DHT_TIME_DISCREPANCY_SECONDS = 3  # max allowed difference between get_dht_t
 DHTExpiration = float
 ROOT = 0
 
-
-class ValueWithExpiration(NamedTuple, Generic[ValueType]):
+@dataclass(init=True, repr=True, frozen=True)
+class ValueWithExpiration(Generic[ValueType]):
     value: ValueType
     expiration_time: DHTExpiration
 
+    def __iter__(self):
+        return iter((self.value, self.expiration_time))
+
+    def __getitem__(self, item):
+        if item == 0:
+            return self.value
+        elif item == 1:
+            return self.expiration_time
+        else:
+            return getattr(self, item)
+
+    def __eq__(self, item):
+        if isinstance(item, ValueWithExpiration):
+            return self.value == item.value and self.expiration_time == item.expiration_time
+        elif isinstance(item, tuple):
+            return tuple.__eq__((self.value, self.expiration_time), item)
+        else:
+            return False
 
-class HeapEntry(NamedTuple, Generic[KeyType]):
+@dataclass(init=True, repr=True, order=True, frozen=True)
+class HeapEntry(Generic[KeyType]):
     expiration_time: DHTExpiration
     key: KeyType
 
-
 class TimedStorage(Generic[KeyType, ValueType]):
     """ A dictionary that maintains up to :maxsize: key-value-expiration tuples until their expiration_time """
     frozen = False  # can be set to True. If true, do not remove outdated elements

+ 13 - 0
tests/test_util_modules.py

@@ -191,3 +191,16 @@ def test_serialize_tensor():
     assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
     restored = hivemind.combine_from_streaming(chunks)
     assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
+
+def test_generic_data_classes():
+    from hivemind.utils import ValueWithExpiration, HeapEntry, DHTExpiration
+
+    value_with_exp = ValueWithExpiration(value="string_value", expiration_time=DHTExpiration(10))
+    assert value_with_exp.value == "string_value" and value_with_exp.expiration_time == DHTExpiration(10)
+
+    heap_entry = HeapEntry(expiration_time=DHTExpiration(10), key="string_value")
+    assert heap_entry.key == "string_value" and heap_entry.expiration_time == DHTExpiration(10)
+
+    sorted_expirations = sorted([DHTExpiration(value) for value in range(1, 1000)])
+    sorted_heap_entry = sorted([HeapEntry(expiration_time=DHTExpiration(value), key="any") for value in range(1, 1000)[::-1]])
+    assert all([heap_entry.expiration_time == value for heap_entry, value in zip(sorted_heap_entry, sorted_expirations)])