Quellcode durchsuchen

Fix dtypes in backend schemas

Aleksandr Borzunov vor 2 Jahren
Ursprung
Commit
e41b788fcc
1 geänderte Datei mit 2 neuen und 3 gelöschten Zeilen
  1. 2 3
      src/server/server.py

+ 2 - 3
src/server/server.py

@@ -286,7 +286,6 @@ class ModuleContainer(threading.Thread):
                 )

                 if load_in_8bit:
-                    dtype = block.input_layernorm.weight.dtype
                     block = replace_8bit_linear(block)

                 block = block.to(device)
@@ -300,13 +299,13 @@ class ModuleContainer(threading.Thread):
                     backend_dtype=None if torch_dtype == "auto" else torch_dtype,
                     args_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     kwargs_schema={},
                     outputs_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     min_batch_size=min_batch_size,