
Fix format bugs on client side

Aleksandr Borzunov committed 3 years ago
parent commit 0336a04082
4 changed files with 20 additions and 8 deletions
  1. README.md (+2 -2)
  2. src/client/remote_block.py (+1 -1)
  3. src/dht_utils.py (+3 -2)
  4. src/server/server.py (+14 -3)

+ 2 - 2
README.md

@@ -43,7 +43,7 @@ python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6
 # - give each server a unique --identity_path (or remove the --identity_path arg when debugging)
 # - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
 # - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
-# - each server except first should have --initial_peers pointing to one of pre-existing servers 
+# - each server except first should have --initial_peers pointing to one of pre-existing servers
 ```
 
 Then open a python notebook or console and run:
@@ -66,7 +66,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
 # test inference, one block
-with layer3.begin_inference_session() as sess:
+with layer3.inference_session() as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```
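The rename drops the begin_ prefix because the method is used as a context manager, not as an explicit begin/end pair. Below is a minimal sketch of that pattern; only inference_session() and step() come from the diff, while the _RemoteBlockSketch/_SessionSketch stand-ins and their bodies are illustrative assumptions:

import torch
from contextlib import contextmanager

class _SessionSketch:
    def step(self, hidden_states):
        # placeholder: the real session forwards hidden_states to the remote block
        return hidden_states

    def close(self):
        pass  # placeholder: the real session releases server-side state

class _RemoteBlockSketch:
    @contextmanager
    def inference_session(self):
        sess = _SessionSketch()
        try:
            yield sess
        finally:
            sess.close()  # runs even if a step raises

layer3 = _RemoteBlockSketch()
with layer3.inference_session() as sess:
    for _ in range(10):
        res = sess.step(torch.ones(1, 1, 4096))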

+ 1 - 1
src/client/remote_block.py

@@ -25,7 +25,7 @@ class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers)))  # TODO replace this
+        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys())))  # TODO replace this
         super().__init__(peer_info, p2p)
 
     @property
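For a dict, list(servers) and list(servers.keys()) are equivalent, so the new spelling mainly documents intent: a random key (a peer ID) is drawn, not a ServerInfo value. Combined with the dht_utils.py change below, those keys are now PeerID objects, presumably what ExpertInfo expects. A tiny sketch with placeholder data:

import random

servers = {"peer_a": ("ONLINE", 1.0), "peer_b": ("ONLINE", 2.0)}  # stands in for Dict[PeerID, ServerInfo]
assert list(servers) == list(servers.keys())  # iterating a dict yields its keys
peer_id = random.choice(list(servers.keys()))  # explicitly choose a key, not a value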

+ 3 - 2
src/dht_utils.py

@@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Sequence, Union
 
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P
+from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 import src
@@ -133,13 +133,14 @@ async def _get_remote_module_infos(
         servers = {}
         for peer_id, server_info in metadata.value.items():
             try:
+                peer_id = PeerID.from_base58(peer_id)
                 server_info = server_info.value
                 if not (isinstance(server_info, tuple) and len(server_info) == 2 and
                         isinstance(server_info[0], int) and isinstance(server_info[1], float)):
                     raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
                 state, throughput = server_info
                 servers[peer_id] = ServerInfo(ServerState(state), throughput)
-            except ValueError as e:
+            except (TypeError, ValueError) as e:
                 logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
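Widening the except clause matters because a malformed DHT entry can fail with TypeError rather than ValueError, for example when a server_info that is not a pair gets unpacked. A self-contained sketch of that defensive parsing with made-up entries (the Qm... strings are sample libp2p peer IDs, not real servers):

from hivemind.p2p import PeerID

raw_entries = {
    "QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN": (0, 1.0),   # parses fine
    "not-base58!": (0, 1.0),                                      # from_base58 raises ValueError
    "QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N": None,       # unpacking raises TypeError
}

servers = {}
for peer_id, server_info in raw_entries.items():
    try:
        peer_id = PeerID.from_base58(peer_id)
        state, throughput = server_info  # TypeError if server_info is not an iterable pair
        servers[peer_id] = (state, throughput)
    except (TypeError, ValueError) as e:
        print(f"Incorrect peer entry for peer_id={peer_id}: {e}")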

+ 14 - 3
src/server/server.py

@@ -41,7 +41,8 @@ class Server(threading.Thread):
         **kwargs,
     ):
         threading.Thread.__init__(self)
-        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
+        self.dht, self.module_backends = dht, module_backends
+        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
@@ -165,8 +166,8 @@ class Server(threading.Thread):
             state=ServerState.JOINING,
             throughput=throughput,
         )
+        logger.info(f"Announced that blocks {block_indices} are joining")
 
-        logger.info(f"Loading blocks with indices {block_indices}")
         blocks = {}
         for module_uid, block_index in zip(module_uids, block_indices):
             block = load_pretrained_block(
@@ -232,6 +233,16 @@ class Server(threading.Thread):
         Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
+        if self.module_backends:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.OFFLINE,
+                throughput=self.throughput,
+            )
+            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+
         self.ready.clear()
 
         for process in self.conn_handlers:
@@ -253,7 +264,7 @@ class Server(threading.Thread):
         logger.debug(f"Shutting down runtime")
 
         self.runtime.shutdown()
-        logger.info("Server shutdown succesfully")
+        logger.info("Server shut down succesfully")
 
 
 class ModuleAnnouncerThread(threading.Thread):
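Taken together, the server.py changes keep throughput and expiration on the instance so that shutdown() can announce the served blocks as OFFLINE with the same parameters used for the JOINING announcement at startup, letting clients stop routing requests to a departing server. A self-contained sketch of that lifecycle; declare_active_modules and get_dht_time are stubbed here, and the uid and numbers are made up:

import time
from enum import Enum

class ServerState(Enum):
    OFFLINE = 0
    JOINING = 1
    ONLINE = 2

def get_dht_time() -> float:
    return time.time()  # stand-in for hivemind.utils.get_dht_time

def declare_active_modules(dht, uids, expiration_time, state, throughput):
    # stub: the real helper stores (state, throughput) under each uid in the DHT
    print(f"{list(uids)} -> {state.name} (throughput={throughput}, expires at {expiration_time:.0f})")

class ServerSketch:
    def __init__(self, dht, module_uids, throughput, expiration):
        # the commit stores throughput and expiration so shutdown() can reuse them
        self.dht, self.module_uids = dht, module_uids
        self.throughput, self.expiration = throughput, expiration

    def run(self):
        declare_active_modules(self.dht, self.module_uids,
                               expiration_time=get_dht_time() + self.expiration,
                               state=ServerState.JOINING, throughput=self.throughput)
        # ... load blocks here, then announce ServerState.ONLINE periodically ...

    def shutdown(self):
        # mirrors the added block: tell peers these blocks are going away
        declare_active_modules(self.dht, self.module_uids,
                               expiration_time=get_dht_time() + self.expiration,
                               state=ServerState.OFFLINE, throughput=self.throughput)

server = ServerSketch(dht=None, module_uids=["bloom6b3.3"], throughput=1.0, expiration=300)
server.run()
server.shutdown()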