|
@@ -3,10 +3,12 @@ import os
|
|
|
import threading
|
|
|
from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
|
|
|
from typing import Dict, Optional
|
|
|
+
|
|
|
import torch
|
|
|
|
|
|
from .connection_handler import handle_connection
|
|
|
from .dht_handler import DHTHandlerThread
|
|
|
+from .checkpoint_saver import CheckpointSaver
|
|
|
from ..dht import DHT
|
|
|
from ..runtime import Runtime, ExpertBackend
|
|
|
|
|
@@ -36,11 +38,15 @@ class Server(threading.Thread):
|
|
|
|
|
|
def __init__(self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
|
|
|
port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
|
|
|
- **kwargs):
|
|
|
+ checkpoint_dir=None, **kwargs):
|
|
|
super().__init__()
|
|
|
self.dht, self.experts, self.update_period = dht, expert_backends, update_period
|
|
|
self.addr, self.port = addr, port
|
|
|
self.conn_handlers = self._create_connection_handlers(conn_handler_processes)
|
|
|
+ if checkpoint_dir is not None:
|
|
|
+ self.checkpoint_saver = CheckpointSaver(expert_backends, checkpoint_dir, update_period)
|
|
|
+ else:
|
|
|
+ self.checkpoint_saver = None
|
|
|
self.runtime = Runtime(self.experts, **kwargs)
|
|
|
|
|
|
if start:
|
|
@@ -58,6 +64,8 @@ class Server(threading.Thread):
|
|
|
dht_handler_thread = DHTHandlerThread(experts=self.experts, dht=self.dht,
|
|
|
addr=self.addr, port=self.port, update_period=self.update_period)
|
|
|
dht_handler_thread.start()
|
|
|
+ if self.checkpoint_saver is not None:
|
|
|
+ self.checkpoint_saver.start()
|
|
|
|
|
|
for process in self.conn_handlers:
|
|
|
if not process.is_alive():
|
|
@@ -68,7 +76,11 @@ class Server(threading.Thread):
|
|
|
for process in self.conn_handlers:
|
|
|
process.join()
|
|
|
if self.dht:
|
|
|
+ dht_handler_thread.stop = True
|
|
|
dht_handler_thread.join()
|
|
|
+ if self.checkpoint_saver is not None:
|
|
|
+ self.checkpoint_saver.stop = True
|
|
|
+ self.checkpoint_saver.join()
|
|
|
|
|
|
def run_in_background(self, await_ready=True, timeout=None):
|
|
|
"""
|
|
@@ -96,12 +108,12 @@ class Server(threading.Thread):
|
|
|
sock = socket(AF_INET, SOCK_STREAM)
|
|
|
sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
|
|
|
sock.bind(('', self.port))
|
|
|
- sock.listen()
|
|
|
+ sock.listen(1024)
|
|
|
sock.settimeout(self.update_period)
|
|
|
|
|
|
processes = [mp.context.ForkProcess(
|
|
|
target=socket_loop, name=f"socket_loop-{i}", args=(sock, self.experts), daemon=True)
|
|
|
- for i in range(num_handlers)]
|
|
|
+ for i in range(num_handlers)]
|
|
|
return processes
|
|
|
|
|
|
def shutdown(self):
|