瀏覽代碼

make comments more polite ;)

justheuristic 3 年之前
父節點
當前提交
c316ca3606
共有 1 個文件被更改,包括 4 次插入4 次删除
  1. 4 4
      src/server/handler.py

+ 4 - 4
src/server/handler.py

@@ -137,12 +137,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
 
         # Modify grad_inputs_schema to support grad_prompts
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
 
 
         grad_inputs_schema_with_prompts = (
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
             requested_backends[0].kwargs_schema,
-        )  # TODO unfuck
+        )  # TODO generalize
 
 
         # Serialize the overall grad_input and respond
         # Serialize the overall grad_input and respond
         return runtime_pb2.ExpertResponse(
         return runtime_pb2.ExpertResponse(
@@ -163,11 +163,11 @@ class TransformerConnectionHandler(ConnectionHandler):
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
         grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends)
 
 
         # Modify grad_inputs_schema to support grad_prompts
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO unfuck
+        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
         grad_inputs_schema_with_prompts = (
         grad_inputs_schema_with_prompts = (
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].args_schema * len(grads),
             requested_backends[0].kwargs_schema,
             requested_backends[0].kwargs_schema,
-        )  # TODO unfuck
+        )  # TODO generalize
 
 
         # Serialize the overall grad_inputs
         # Serialize the overall grad_inputs
         serialized_grad_inputs = [
         serialized_grad_inputs = [