Browse Source

fix reference

justheuristic 3 years ago
parent
commit
217f109723
2 changed files with 8 additions and 6 deletions
  1. 8 5
      src/client/remote_block.py
  2. 0 1
      src/client/remote_sequential.py

+ 8 - 5
src/client/remote_block.py

@@ -57,6 +57,7 @@ class RemoteTransformerBlockInferenceSession:
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.stepped = False
         self.closed = False
 
     @classmethod
@@ -102,6 +103,7 @@ class RemoteTransformerBlockInferenceSession:
     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
+        self.stepped = True
         return await anext(self._outputs_stream)
 
     def close(self):
@@ -116,11 +118,12 @@ class RemoteTransformerBlockInferenceSession:
         """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
         if self._outputs_stream is None:
             return  # already closed
-        await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
-        try:
-            await anext(self._outputs_stream)
-        except StopAsyncIteration:
-            pass
+        if self.stepped:
+            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+            try:
+                await anext(self._outputs_stream)
+            except StopAsyncIteration:
+                pass
 
     def __del__(self):
         self.close()

+ 0 - 1
src/client/remote_sequential.py

@@ -75,7 +75,6 @@ class RemoteSequential(nn.Sequential):
         return RemoteSequentialInferenceSession(self.remote_sequence_info)
 
 
-
 class RemoteSequentialInferenceSession:
     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""