|
@@ -1,7 +1,7 @@
|
|
|
import threading
|
|
|
from datetime import datetime
|
|
|
from pathlib import Path
|
|
|
-from shutil import copytree
|
|
|
+from shutil import copy2
|
|
|
from tempfile import TemporaryDirectory
|
|
|
from typing import Dict
|
|
|
import os
|
|
@@ -18,6 +18,18 @@ def dir_is_correct(directory: Path):
|
|
|
return True
|
|
|
|
|
|
|
|
|
def copy_tree(src: str, dst: str):
    """Recursively copy the contents of directory ``src`` into ``dst``.

    ``dst`` (and any missing parents) is created when absent; when it already
    exists, the copy is merged into it: entries with the same name are
    overwritten and unrelated entries in ``dst`` are left alone.

    Regular files are copied with ``shutil.copy2``. Symbolic links are
    followed (both by ``os.path.isdir`` and by ``copy2``'s default), so the
    destination receives the link target's content rather than a link.

    :param src: path of an existing directory to copy from
    :param dst: path of the directory to copy into
    :raises FileExistsError: if ``dst`` exists but is not a directory
    :raises OSError: propagated from listing ``src`` or copying an entry
    """
    # exist_ok=True replaces a check-then-create sequence: it is safe when
    # another thread/process creates dst concurrently, and it fails up front
    # when dst is an existing non-directory instead of erroring per file later.
    os.makedirs(dst, exist_ok=True)
    for item in os.listdir(src):
        src_entry = os.path.join(src, item)
        dst_entry = os.path.join(dst, item)
        if os.path.isdir(src_entry):
            copy_tree(src_entry, dst_entry)
        else:
            copy2(src_entry, dst_entry)
|
|
|
+
|
|
|
+
|
|
|
class CheckpointSaver(threading.Thread):
|
|
|
def __init__(self, expert_backends: Dict[str, ExpertBackend], checkpoint_dir: Path, update_period: int):
|
|
|
super().__init__()
|
|
@@ -45,7 +57,7 @@ def store_experts(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
|
|
|
checkpoint_name = expert_dir / f'checkpoint_{timestamp}.pt'
|
|
|
torch.save(expert_backend.state_dict(), checkpoint_name)
|
|
|
os.symlink(checkpoint_name, expert_dir / 'checkpoint_last.pt')
|
|
|
- copytree(tmpdirname, str(checkpoint_dir), dirs_exist_ok=True)
|
|
|
+ copy_tree(tmpdirname, str(checkpoint_dir))
|
|
|
|
|
|
|
|
|
def load_weights(experts: Dict[str, ExpertBackend], checkpoint_dir: Path):
|