
Support various dtypes for inference

The client now casts inputs to the dtype declared in the server's forward schema before serializing them; the server casts incoming tensors to its backend dtype, casts results back to the output schema dtype before sending them, and allocates the attention cache in the backend dtype instead of hardcoded torch.float32.

dbaranchuk · 3 years ago
Commit 6165aca545
2 files changed, 6 insertions(+), 3 deletions(-)
  1. src/client/inference_session.py (+1, -1)
  2. src/server/handler.py (+5, -2)

src/client/inference_session.py (+1, -1)

@@ -70,7 +70,7 @@ class RemoteTransformerBlockInferenceSession:
                 runtime_pb2.ExpertRequest(
                     uid=self.uid,
                     tensors=[
-                        serialize_torch_tensor(tensor, proto.compression)
+                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                         for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
                     ],
                 )
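
A minimal sketch of the client-side change in plain torch, with illustrative names and shapes (the `schema_dtype` variable here is a hypothetical stand-in for the `.dtype` field of a `forward_schema` entry, not an identifier from the commit):

```python
import torch

# Hypothetical stand-in for one entry of rpc_info["forward_schema"];
# in the real code it is a proto carrying .dtype and .compression.
schema_dtype = torch.bfloat16  # assumption: the remote backend runs in bfloat16

# Activations on the client side typically start out in float32...
hidden_states = torch.randn(1, 8, 4096)  # illustrative shape

# ...so cast to the schema dtype before serialization, as the updated
# line does with tensor.to(proto.dtype).
payload = hidden_states.to(schema_dtype)
assert payload.dtype == schema_dtype
```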

src/server/handler.py (+5, -2)

@@ -48,6 +48,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                 while request.tensors:  # iterate while user is willing to supply tensors
                     hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
 
+                    # Cast inputs to backend dtype
+                    hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
+
                     # run request tensors through all requested modules, update caches
                     for backend, cache_handle in zip(requested_backends, cache_handles):
                         cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
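
A sketch of what the added cast does on the server side, assuming a hypothetical backend loaded in float16 (the dtype and shapes below are illustrative, not taken from the commit):

```python
import torch

# Assumption: the first requested backend was loaded in float16.
backend_dtype = torch.float16

# Deserialized request tensors may arrive in float32 (the wire dtype
# depends on the client and on the compression codec).
request_tensors = [torch.randn(1, 8, 4096), torch.randn(1, 8, 4096)]

# Mirror of the added handler lines: cast every input to the backend
# dtype so all requested modules see consistent inputs.
hidden_states = [tensor.to(backend_dtype) for tensor in request_tensors]
assert all(t.dtype == backend_dtype for t in hidden_states)
```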
@@ -62,7 +65,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     # serialize and send last layer outputs
                     yield runtime_pb2.ExpertResponse(
                         tensors=[
-                            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                             for result, proto in zip(
                                 hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                             )
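
The output path is symmetric: results are cast to the dtype declared in the last backend's `outputs_schema` before serialization, so the client receives tensors in the dtype it expects. A tiny sketch, with an assumed schema dtype:

```python
import torch

# Assumption: outputs_schema declares float32 results for the client.
output_dtype = torch.float32

# The backend may produce results in its own dtype (e.g. float16)...
result = torch.randn(1, 8, 4096, dtype=torch.float16)

# ...so cast back before serializing, as result.to(proto.dtype) does.
wire_tensor = result.to(output_dtype)
assert wire_tensor.dtype == output_dtype
```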
@@ -242,7 +245,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 head_dim = backend.module.self_attention.head_dim
 
                 cache_descriptor = TensorDescriptor(
-                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32
+                    size=(2, batch_size, MAX_LENGTH, num_heads, head_dim), dtype=backend.dtype
                 )
                 # [key_or_value, batch_size, max_length, num_heads, head_dim]
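
Allocating the attention cache in `backend.dtype` rather than hardcoded `torch.float32` halves cache memory for fp16/bf16 backends. A sketch with purely illustrative sizes (the real values come from `backend.module.self_attention` and the server configuration):

```python
import torch

# Illustrative sizes; the real values come from the backend module
# (num_heads, head_dim) and the server config (batch size, MAX_LENGTH).
batch_size, MAX_LENGTH, num_heads, head_dim = 1, 2048, 32, 128
backend_dtype = torch.float16  # assumption: backend weights loaded in fp16

# Layout: [key_or_value, batch_size, max_length, num_heads, head_dim]
cache = torch.empty(2, batch_size, MAX_LENGTH, num_heads, head_dim, dtype=backend_dtype)

fp32_mib = cache.numel() * 4 / 2**20
actual_mib = cache.numel() * cache.element_size() / 2**20
print(f"float32 cache: {fp32_mib:.0f} MiB vs {backend_dtype} cache: {actual_mib:.0f} MiB")
```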