|
@@ -33,7 +33,7 @@ from petals.server.memory_cache import MemoryCache
|
|
|
from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
|
|
|
from petals.server.throughput import get_dtype_name, get_server_throughput
|
|
|
from petals.utils.auto_config import AutoDistributedConfig
|
|
|
-from petals.utils.convert_block import QuantType, check_device_balance, convert_block
|
|
|
+from petals.utils.convert_block import QuantType, check_device_balance, convert_block, is_gptq_quant
|
|
|
from petals.utils.dht import declare_active_modules, get_remote_module_infos
|
|
|
from petals.utils.misc import get_size_in_bytes
|
|
|
from petals.utils.ping import PingAggregator
|
|
@@ -428,6 +428,8 @@ class Server:
|
|
|
self.dht.join()
|
|
|
|
|
|
|
|
|
+
|
|
|
+
|
|
|
class ModuleContainer(threading.Thread):
|
|
|
"""Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
|
|
|
|
|
@@ -495,19 +497,20 @@ class ModuleContainer(threading.Thread):
|
|
|
cache_dir=cache_dir,
|
|
|
max_disk_space=max_disk_space,
|
|
|
)
|
|
|
- block = convert_block(
|
|
|
- block,
|
|
|
- block_index,
|
|
|
- block_config,
|
|
|
- tensor_parallel_devices,
|
|
|
- device,
|
|
|
- quant_type,
|
|
|
- adapters=server_info.adapters,
|
|
|
- freeze=True,
|
|
|
- token=token,
|
|
|
- cache_dir=cache_dir,
|
|
|
- max_disk_space=max_disk_space,
|
|
|
- )
|
|
|
+ if not is_gptq_quant(block_config):
|
|
|
+ block = convert_block(
|
|
|
+ block,
|
|
|
+ block_index,
|
|
|
+ block_config,
|
|
|
+ tensor_parallel_devices,
|
|
|
+ device,
|
|
|
+ quant_type,
|
|
|
+ adapters=server_info.adapters,
|
|
|
+ freeze=True,
|
|
|
+ token=token,
|
|
|
+ cache_dir=cache_dir,
|
|
|
+ max_disk_space=max_disk_space,
|
|
|
+ )
|
|
|
blocks[module_uid] = TransformerBackend(
|
|
|
module_uid,
|
|
|
block,
|
|
@@ -554,6 +557,7 @@ class ModuleContainer(threading.Thread):
|
|
|
**kwargs,
|
|
|
)
|
|
|
|
|
|
+
|
|
|
def __init__(
|
|
|
self,
|
|
|
dht: DHT,
|