瀏覽代碼

Remove unused imports and attributes (#324)

* Remove unused imports and attributes
Max Ryabinin 2 年之前
父節點
當前提交
3e7ae5116d

+ 0 - 1
src/petals/client/inference_session.py

@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import asyncio
 import itertools
-import logging
 import time
 from typing import AsyncIterator, List, Optional
 

+ 1 - 2
src/petals/client/remote_model.py

@@ -1,6 +1,5 @@
-import os
 from contextlib import contextmanager
-from typing import Collection, List, Optional, Union
+from typing import List, Optional, Union
 
 import hivemind
 import torch

+ 0 - 1
src/petals/client/remote_sequential.py

@@ -4,7 +4,6 @@ from typing import Optional, Union
 
 import torch
 from hivemind import DHT, get_logger
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
 import petals.client

+ 1 - 1
src/petals/client/routing/sequence_manager.py

@@ -11,7 +11,7 @@ from typing import Any, Collection, Dict, List, Optional, Sequence, Union
 from weakref import WeakMethod
 
 import numpy as np
-from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time
+from hivemind import DHT, P2P, MSGPackSerializer, PeerID
 from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2

+ 0 - 1
src/petals/client/sequential_autograd.py

@@ -3,7 +3,6 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 import asyncio
 import itertools
-import logging
 from collections import deque
 from typing import List, Optional, Sequence, Tuple
 

+ 0 - 2
src/petals/dht_utils.py

@@ -8,11 +8,9 @@ from functools import partial
 from typing import Dict, List, Optional, Sequence, Union
 
 from hivemind.dht import DHT, DHTNode, DHTValue
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
 
-import petals.client
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
 logger = get_logger(__name__)

+ 1 - 1
src/petals/server/backend.py

@@ -16,7 +16,7 @@ from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
 from petals.data_structures import InferenceMetadata
-from petals.server.memory_cache import Handle, MemoryCache
+from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
 

+ 2 - 2
src/petals/server/memory_cache.py

@@ -10,7 +10,7 @@ import ctypes
 import multiprocessing as mp
 import os
 import time
-from typing import AsyncContextManager, Dict, Optional, Sequence, Tuple
+from typing import AsyncContextManager, Dict, Optional, Sequence
 
 import hivemind
 import torch
@@ -29,7 +29,7 @@ class MemoryCache:
     def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.alloc_timeout = alloc_timeout
-        self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+        self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._allocated_tensors: Dict[Handle, torch.Tensor] = {}

+ 0 - 1
src/petals/server/reachability.py

@@ -5,7 +5,6 @@ import time
 from concurrent.futures import Future
 from contextlib import asynccontextmanager
 from functools import partial
-from secrets import token_hex
 from typing import Optional
 
 import requests

+ 0 - 5
src/petals/server/server.py

@@ -8,7 +8,6 @@ import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
 
-import numpy as np
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -502,7 +501,6 @@ class ModuleContainer(threading.Thread):
             expiration=expiration,
             daemon=True,
         )
-        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
 
         if start:
             self.run_in_background(await_ready=True)
@@ -517,9 +515,6 @@ class ModuleContainer(threading.Thread):
 
         self.online_announcer.start()
 
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
         for handler in self.conn_handlers:
             handler.run_in_background()
 

+ 0 - 1
src/petals/utils/generation_algorithms.py

@@ -85,7 +85,6 @@ class NucleusAlgorithm(SamplingAlgorithm):
 class BeamSearchAlgorithm(DecodingAlgorithm):
     def __init__(self, num_beams: int, batch_size: int) -> None:
         self.num_beams = num_beams
-        self._cur_num_beams = 1
         self.batch_size = batch_size
 
         self._batch_beams = [list() for _ in range(batch_size)]

+ 0 - 1
src/petals/utils/logging.py

@@ -1,4 +1,3 @@
-import importlib
 import os
 
 from hivemind.utils import logging as hm_logging

+ 0 - 1
tests/test_block_exact_match.py

@@ -8,7 +8,6 @@ from transformers.models.bloom.configuration_bloom import BloomConfig
 from petals.bloom.block import WrappedBloomBlock
 from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
 from petals.client import DistributedBloomConfig, RemoteSequential
-from petals.data_structures import UID_DELIMITER
 from test_utils import *
 
 

+ 0 - 1
tests/test_server_stats.py

@@ -5,7 +5,6 @@ import pytest
 import torch
 
 from petals.client import DistributedBloomConfig, RemoteSequential
-from petals.data_structures import UID_DELIMITER
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *