Bläddra i källkod

Remove unused imports, add missing arguments to docstrings (#108)

* Remove unused imports, add missing arguments to docstrings
Max Ryabinin 2 år sedan
förälder
incheckning
9faf08b898

+ 0 - 1
src/petals/bloom/model.py

@@ -21,7 +21,6 @@ from transformers.modeling_outputs import (
     CausalLMOutputWithCrossAttentions,
     SequenceClassifierOutputWithPast,
 )
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
 from transformers.utils import logging

+ 0 - 4
src/petals/bloom/ops.py

@@ -196,10 +196,6 @@ class BloomScaledSoftmax(nn.Module):
     fused operation: scaling + mask + softmax
 
     Args:
-        input_in_fp16 (`bool`, *required*):
-            flag to indicate if input in fp16 data format.
-        input_in_bf16 (`bool`, *required*):
-            flag to indicate if input in bf16 data format.
         scaled_masked_softmax_fusion (`bool`, *required*):
             flag to indicate user want to use softmax fusion
         mask_func (`function`, *required*):

+ 1 - 0
src/petals/client/remote_generation.py

@@ -57,6 +57,7 @@ class RemoteGenerationMixin:
         :param bos_token_id: The id of the beginning of sentence token.
         :param eos_token_id: The id of the end of sentence token.
         :param pad_token_id: The id of the padding token.
+        :param max_length: The maximum number of tokens in the output (including input tokens).
         :param max_new_tokens: The maximum number of tokens to generate.
         :param decoding_algorithm: The decoding algorithm to use.
         :param provided_constraints: A list of constraints to use.

+ 0 - 1
src/petals/client/sequential_autograd.py

@@ -51,7 +51,6 @@ async def sequential_forward(
     sequences = deque()
     intermediate_inputs = []
     done_sequences = []
-    outputs = inputs
 
     block_idx = start_index
     while block_idx < end_index:

+ 1 - 1
src/petals/server/backend.py

@@ -1,5 +1,5 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Dict, Sequence, Tuple
 
 import torch
 from hivemind import BatchTensorDescriptor, use_hivemind_log_handler

+ 1 - 1
src/petals/server/handler.py

@@ -17,7 +17,7 @@ from hivemind import (
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
+from hivemind.utils.asyncio import amap_in_executor, anext
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 

+ 1 - 1
src/petals/server/task_pool.py

@@ -4,7 +4,7 @@ import threading
 import time
 from dataclasses import dataclass, field
 from queue import PriorityQueue
-from typing import Any, Generator, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import MPFuture, get_logger, use_hivemind_log_handler

+ 0 - 1
src/petals/server/task_prioritizer.py

@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
 
 import torch
-from hivemind.moe.server.task_pool import Task
 
 
 class TaskPrioritizerBase(ABC):