Bladeren bron

Logging + got rid of batch_received (#52)

* Fix empty list of neighbors

* Add basic logger creation

* Introduce better logging, adjust tests a bit

* Add empty line

* Remove aiologger from requirements

* Revert test_dht

* Update tests

* More debug logging for backend
Max Ryabinin 5 jaren geleden
bovenliggende
commit
fe68aa1050

+ 1 - 1
hivemind/__init__.py

@@ -1,6 +1,6 @@
 from .client import *
 from .dht import *
-from .server import *
+from .server import Server
 from .utils import *
 from .runtime import *
 

+ 23 - 10
hivemind/dht/protocol.py

@@ -1,23 +1,30 @@
 from __future__ import annotations
-import os
-import heapq
+
 import asyncio
-import logging
+import heapq
+import os
 import urllib.parse
 from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union
 from warnings import warn
+
+import grpc
+import grpc.experimental.aio
+
 from .routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
-from ..utils import Endpoint, compile_grpc
-import grpc, grpc.experimental.aio
+from ..utils import Endpoint, compile_grpc, get_logger
+
+logger = get_logger(__name__)
 
 with open(os.path.join(os.path.dirname(__file__), 'dht.proto'), 'r') as f_proto:
     dht_pb2, dht_grpc = compile_grpc(f_proto.read())
 
 
 class DHTProtocol(dht_grpc.DHTServicer):
+    # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
     storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable
+    # fmt:on
 
     @classmethod
     async def create(cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
@@ -87,7 +94,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         try:
             peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
         except grpc.experimental.aio.AioRpcError as error:
-            logging.info(f"DHTProtocol failed to ping {peer}: {error.code()}")
+            logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
             peer_info = None
         responded = bool(peer_info and peer_info.node_id)
         peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
@@ -134,7 +141,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
         except grpc.experimental.aio.AioRpcError as error:
-            logging.info(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
             return [False] * len(keys)
 
@@ -180,7 +187,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 output[key] = (value, expiration, nearest)
             return output
         except grpc.experimental.aio.AioRpcError as error:
-            logging.info(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
 
     async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
@@ -197,8 +204,13 @@ class DHTProtocol(dht_grpc.DHTServicer):
             cached_value, cached_expiration = self.cache.get(key_id)
             if (cached_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
                 maybe_value, maybe_expiration = cached_value, cached_expiration
-            peer_ids, endpoints = zip(*self.routing_table.get_nearest_neighbors(
-                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)))
+
+            nearest_neighbors = self.routing_table.get_nearest_neighbors(
+                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id))
+            if nearest_neighbors:
+                peer_ids, endpoints = zip(*nearest_neighbors)
+            else:
+                peer_ids, endpoints = [], []
 
             response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
             response.expiration.append(maybe_expiration if maybe_expiration is not None else _NOT_FOUND_EXPIRATION)
@@ -247,6 +259,7 @@ _NOT_FOUND_VALUE, _NOT_FOUND_EXPIRATION = b'', -float('inf')  # internal values
 
 class LocalStorage:
     """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration) """
+
     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
         self.data = dict()

+ 12 - 3
hivemind/runtime/__init__.py

@@ -10,6 +10,9 @@ from prefetch_generator import BackgroundGenerator
 
 from .expert_backend import ExpertBackend
 from .task_pool import TaskPool, TaskPoolBase
+from hivemind.utils import get_logger
+
+logger = get_logger(__name__)
 
 
 class Runtime(threading.Thread):
@@ -34,6 +37,7 @@ class Runtime(threading.Thread):
     :param device: if specified, moves all experts and data to this device via .to(device=device).
       If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
     """
+
     def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1,
                  device: torch.device = None):
         super().__init__()
@@ -44,7 +48,6 @@ class Runtime(threading.Thread):
         self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches
 
     def run(self):
-        progress = tqdm.tqdm(bar_format='{desc}, {rate_fmt}')
         for pool in self.pools:
             if not pool.is_alive():
                 pool.start()
@@ -55,13 +58,15 @@ class Runtime(threading.Thread):
         with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
             try:
                 self.ready.set()
+                logger.info("Started")
                 for pool, batch_index, batch in BackgroundGenerator(
                         self.iterate_minibatches_from_pools(), self.prefetch_batches):
+                    logger.debug(f"Processing batch {batch_index} from pool {pool.uid}")
                     outputs = pool.process_func(*batch)
+                    logger.info(f"Pool {pool.uid}: batch {batch_index} processed, size {outputs[0].size(0)}")
                     output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])
-                    progress.update(len(outputs[0]))
-                    progress.desc = f'pool.uid={pool.uid} batch_size={len(outputs[0])}'
             finally:
+                logger.info("Shutting down")
                 self.shutdown()
 
     SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"
@@ -85,12 +90,16 @@ class Runtime(threading.Thread):
 
             while True:
                 # wait until at least one batch_receiver becomes available
+                logger.debug("Waiting for inputs from task pools")
                 ready_fds = selector.select()
                 ready_objects = {key.data for (key, events) in ready_fds}
                 if self.SHUTDOWN_TRIGGER in ready_objects:
                     break  # someone asked us to shutdown, break from the loop
 
+                logger.debug("Choosing the pool with highest priority")
                 pool = max(ready_objects, key=lambda pool: pool.priority)
 
+                logger.debug(f"Loading batch from {pool.uid}")
                 batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
+                logger.debug(f"Loaded batch from {pool.uid}")
                 yield pool, batch_index, batch_tensors

+ 14 - 13
hivemind/runtime/task_pool.py

@@ -14,8 +14,9 @@ from typing import List, Tuple, Dict, Any, Generator
 
 import torch
 
-from ..utils import SharedFuture
+from hivemind.utils import SharedFuture, get_logger
 
+logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
 
 
@@ -78,7 +79,6 @@ class TaskPool(TaskPoolBase):
 
         # interaction with Runtime
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain batch inputs
-        self.batch_received = mp.Event()  # runtime can notify pool that it can send next batch
         self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False)  # send/recv arrays that contain outputs
 
         if start:
@@ -107,12 +107,11 @@ class TaskPool(TaskPoolBase):
                 batch = []
                 total_size = 0
             try:
+                logger.debug(f"{self.uid} getting next task")
                 task = self.tasks.get(timeout=self.timeout)
             except Empty:
-                exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.")
-                for task in batch:
-                    task.future.set_exception(exc)
-                raise exc
+                logger.warning(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet")
+                continue
 
             task_size = self.get_task_size(task)
 
@@ -126,10 +125,10 @@ class TaskPool(TaskPoolBase):
                 total_size += task_size
 
     def run(self, *args, **kwargs):
-        print(f'Starting pool, pid={os.getpid()}')
+        logger.info(f'{self.uid} starting, pid={os.getpid()}')
         pending_batches = {}  # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],
-                                         name=f'{self.uid}-pool_output_loop')
+                                         name=f'{self.uid}_output')
         try:
             output_thread.start()
             self._pool_input_loop(pending_batches, *args, **kwargs)
@@ -144,11 +143,8 @@ class TaskPool(TaskPoolBase):
         prev_num_tasks = 0  # number of tasks currently in shared buffer
         batch_index = max(pending_batches.keys(), default=0)
         batch_iterator = self.iterate_minibatches(*args, **kwargs)
-        self.batch_received.set()  # initial state: no batches/outputs pending
 
         while True:
-            self.batch_received.wait()  # wait for runtime to receive (copy) previous batch
-
             # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task
             # assumes that tasks are processed in the same order as they are created
             for skip_i in range(prev_num_tasks):
@@ -156,18 +152,21 @@ class TaskPool(TaskPoolBase):
                 if skip_i == prev_num_tasks - 1:
                     self.priority = finished_task_timestamp
 
+            logger.debug(f"{self.uid} getting next batch")
             batch_tasks = next(batch_iterator)
             # save batch futures, _output_loop will deliver on them later
             pending_batches[batch_index] = batch_tasks
 
+            logger.debug(f"{self.uid}, batch  {batch_index}: aggregating inputs")
             # find or create shared arrays for current batch size
             batch_inputs = [
                 torch.cat([task.args[i] for task in batch_tasks]).share_memory_()
                 for i in range(len(batch_tasks[0].args))
             ]
 
-            self.batch_received.clear()  # sending next batch...
+            logger.debug(f"{self.uid}, batch {batch_index}: sending to runtime")
             self.batch_sender.send((batch_index, batch_inputs))
+            logger.debug(f"{self.uid}, batch {batch_index}: sent to runtime")
             prev_num_tasks = len(batch_tasks)
             batch_index += 1
 
@@ -175,16 +174,19 @@ class TaskPool(TaskPoolBase):
         """ Infinite loop: receive results from runtime and dispatch them to task Futures """
 
         while True:
+            logger.debug(f"{self.uid} waiting for results from runtime")
             payload = self.outputs_receiver.recv()
             if isinstance(payload, BaseException):
                 raise payload
             else:
                 batch_index, batch_outputs = payload
+            logger.debug(f"{self.uid}, batch {batch_index}: got results")
 
             # split batch into partitions for individual tasks
             batch_tasks = pending_batches.pop(batch_index)
             task_sizes = [self.get_task_size(task) for task in batch_tasks]
             outputs_per_task = zip(*(torch.split_with_sizes(array, task_sizes, dim=0) for array in batch_outputs))
+            logger.debug(f"{self.uid}, batch {batch_index}: sending outputs to handlers")
 
             # dispatch results to futures
             for task, task_outputs in zip(batch_tasks, outputs_per_task):
@@ -200,7 +202,6 @@ class TaskPool(TaskPoolBase):
             raise TimeoutError()
 
         batch_index, batch_inputs = self.batch_receiver.recv()
-        self.batch_received.set()  # pool can now prepare next batch
         batch_inputs = [tensor.to(device, non_blocking=True) for tensor in batch_inputs]
         return batch_index, batch_inputs
 

+ 1 - 0
hivemind/utils/__init__.py

@@ -7,3 +7,4 @@ from .shared_future import *
 from .threading import *
 from .autograd import *
 from .grpc import *
+from .logging import get_logger

+ 31 - 0
hivemind/utils/logging.py

@@ -0,0 +1,31 @@
+import logging
+import os
+
+
+def get_logger(module_name: str) -> logging.Logger:
+    """
+    Create (or fetch) the logger for a module and attach the project-wide stream handler to it.
+
+    :param module_name: dotted module name, typically ``__name__``; the leading package name is trimmed
+    :returns: logger whose level is taken from the LOGLEVEL environment variable (default: INFO)
+
+    The function is idempotent: calling it again for the same module does not attach a second handler.
+    """
+    # trim package name; keep the full name for top-level modules so that we never grab the root logger
+    name_without_prefix = '.'.join(module_name.split('.')[1:]) or module_name
+    loglevel = os.getenv('LOGLEVEL', 'INFO')
+
+    logging.addLevelName(logging.WARNING, 'WARN')
+    logger = logging.getLogger(name_without_prefix)
+    logger.setLevel(loglevel)
+    # every logger owns its handler, so records must not also bubble up to a parent logger created here
+    # (e.g. 'runtime.task_pool' -> 'runtime'), otherwise each message would be printed twice
+    logger.propagate = False
+
+    if not logger.handlers:
+        formatter = logging.Formatter(fmt='[{asctime}.{msecs:03.0f}][{levelname}][{name}.{funcName}:{lineno}] {message}', style='{',
+                                      datefmt='%Y/%m/%d %H:%M:%S')
+        handler = logging.StreamHandler()
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+    return logger

+ 1 - 2
requirements.txt

@@ -7,5 +7,4 @@ prefetch_generator>=1.0.1
 pytest
 umsgpack
 grpcio
-grpcio-tools>=1.30.0
-aiologger>=0.5.0
+grpcio-tools>=1.30.0

+ 8 - 7
tests/test_dht.py

@@ -31,7 +31,7 @@ def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event
     print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
 
 
-def test_kademlia_protocol():
+def test_dht_protocol():
     # create the first peer
     peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
     peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
@@ -72,7 +72,7 @@ def test_kademlia_protocol():
                 f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
 
             assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-                f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+                                                                          f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
 
             # peer 2 must know about peer 1, but not have a *random* nonexistent value
             dummy_key = DHTID.generate()
@@ -250,15 +250,15 @@ def test_hivemind_dht():
 
 def test_store():
     d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 10)
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
     assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
     print("Test store passed")
 
 
 def test_get_expired():
     d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val", get_dht_time() + 1)
-    time.sleep(2)
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
+    time.sleep(0.5)
     assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
     print("Test get expired passed")
 
@@ -271,9 +271,10 @@ def test_get_empty():
 
 def test_change_expiration_time():
     d = LocalStorage()
-    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 2)
+    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
+    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
     d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
-    time.sleep(4)
+    time.sleep(1)
     assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
     print("Test change expiration time passed")
 

+ 2 - 2
tests/test_moe.py

@@ -18,7 +18,7 @@ def test_remote_module_call():
     logits = torch.randn(3, requires_grad=True)
     random_proj = torch.randn_like(xx)
 
-    with background_server(num_experts=num_experts, device='cpu',
+    with background_server(num_experts=num_experts, device='cpu', num_handlers=1,
                            no_optimizer=True, no_dht=True) as (localhost, server_port, dht_port):
         experts = [hivemind.RemoteExpert(uid=f'expert.{i}', port=server_port) for i in range(num_experts)]
         moe_output, = hivemind.client.moe._RemoteMoECall.apply(
@@ -50,7 +50,7 @@ def test_determinism():
     xx = torch.randn(32, 1024, requires_grad=True)
     mask = torch.randint(0, 1, (32, 1024))
 
-    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout',
+    with background_server(num_experts=1, device='cpu', expert_cls='det_dropout', num_handlers=1,
                            no_optimizer=True, no_dht=True) as (interface, server_port, dht_port):
         expert = hivemind.RemoteExpert(uid=f'expert.0', port=server_port)