5
0
Эх сурвалжийг харах

make comments more polite ;)

justheuristic 3 жил өмнө
parent
commit
c316ca3606
1 өөрчлөгдсөн 4 нэмэгдсэн , 4 устгасан
  1. 4 4
      src/server/handler.py

+ 4 - 4
src/server/handler.py

@@ -137,12 +137,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
 
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
-        )  # TODO unfuck
+        )  # TODO generalize
 
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
@@ -163,11 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
-        )  # TODO unfuck
+        )  # TODO generalize
 
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [