Browse Source

rm prefix from tests

justheuristic 3 years ago
parent
commit
efa0d16714
2 changed files with 7 additions and 1 deletion
  1. + 6 - 0 src/server/handler.py
  2. + 1 - 1 tests/test_full_model.py

+ 6 - 0
src/server/handler.py

@@ -71,6 +71,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             print("CLOSED RPC_INFERENCE")
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        return await super().rpc_forward(request, context)
         # Parse request and prepare backends
         hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_header(request)
@@ -96,6 +97,8 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_forward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
+        async for response in super().rpc_forward_stream(requests, context):
+            yield response
         # Parse requests and prepare backends
         uids_header, hidden_states = await self._gather_inputs(requests, context)
         requested_uids = self._check_header_str(uids_header)
@@ -124,6 +127,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield runtime_pb2.ExpertResponse(tensors=[part])
 
     async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+        return await super().rpc_backward(request, context)
         # Parse requests and prepare backends
         inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         requested_uids = self._check_header(request)
@@ -157,6 +161,8 @@ class TransformerConnectionHandler(ConnectionHandler):
     async def rpc_backward_stream(
         self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
+        async for response in super().rpc_backward_stream(requests, context):
+            yield response
         uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
         inputs, grads = inputs_and_grads
         requested_uids = self._check_header_str(uids_header)
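
The handler.py hunks all apply the same pattern: each overridden RPC method now short-circuits to the parent ConnectionHandler before its block-specific body runs, so in the unary methods everything after the early return becomes unreachable. Below is a minimal sketch of the pattern, using hypothetical simplified classes and plain string requests in place of the real protobuf types:

import asyncio
from typing import AsyncIterator


class ConnectionHandlerBase:
    # Stand-in for the parent ConnectionHandler.
    async def rpc_forward(self, request: str) -> str:
        return f"parent handled {request}"

    async def rpc_forward_stream(self, requests: AsyncIterator[str]) -> AsyncIterator[str]:
        async for request in requests:
            yield f"parent handled {request}"


class TransformerHandler(ConnectionHandlerBase):
    async def rpc_forward(self, request: str) -> str:
        # Early return: the parent implementation answers the RPC;
        # any subclass-specific code placed after this line never runs.
        return await super().rpc_forward(request)

    async def rpc_forward_stream(self, requests: AsyncIterator[str]) -> AsyncIterator[str]:
        # Async generators cannot `return` a value, so delegation
        # re-yields each of the parent's responses instead.
        async for response in super().rpc_forward_stream(requests):
            yield response


async def main() -> None:
    handler = TransformerHandler()
    print(await handler.rpc_forward("req-0"))

    async def gen() -> AsyncIterator[str]:
        yield "req-1"

    async for response in handler.rpc_forward_stream(gen()):
        print(response)


asyncio.run(main())

Note one asymmetry: in the streaming methods the original body is still reachable after the parent generator is exhausted, whereas in the unary methods the `return` cuts the rest of the method off entirely.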

+ 1 - 1
tests/test_full_model.py

@@ -24,7 +24,7 @@ if not MODEL_NAME:
 REF_NAME = os.environ.get("REF_NAME")
 
 
-def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"):
+def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     assert len(model.transformer.h) == model.config.n_layer