@@ -9,7 +9,9 @@ import time
 from typing import Dict, List, Optional, Sequence, Union
 
 import hivemind
+import psutil
 import torch
+import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
@@ -154,13 +156,27 @@ class Server:
         self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
         device = torch.device(device)
         if device.type == "cuda" and device.index is None:
             device = torch.device(device.type, index=0)
         self.device = device
 
         torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
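+        # Guard against dtype/device combinations that PyTorch cannot run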
+        if device.type == "cpu" and torch_dtype == torch.float16:
+            raise ValueError(
+                "Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
+            )
+        if device.type == "mps" and torch_dtype == torch.bfloat16:
+            logger.warning("Type bfloat16 is not supported on MPS, using float16 instead")
+            torch_dtype = torch.float16
         self.torch_dtype = torch_dtype
 
         if tensor_parallel_devices is None:
@@ -253,13 +269,15 @@ class Server:
         self.stop = threading.Event()
 
     def _choose_num_blocks(self) -> int:
-        assert self.device.type == "cuda", (
+        assert self.device.type in ("cuda", "mps"), (
            "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
            "CPU-only servers in the public swarm are discouraged since they are much slower"
        )
        num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
 
        if num_devices > 1:
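+            # Tensor parallelism queries torch.cuda device properties below, so it remains CUDA-only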
+            assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}"
            memory_per_device = tuple(
                torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
            )
@@ -270,8 +288,11 @@ class Server:
                "Please launch individual servers on each GPU or set --num_blocks manually to "
                "override this exception."
            )
-        else:
+        elif self.device.type == "cuda":
            total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        else:
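+            # On MPS and CPU there is no separate VRAM, so budget against total system RAM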
+            total_memory = psutil.virtual_memory().total
 
        gib = 1024**3
        # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
@@ -373,6 +394,9 @@ class Server:
                f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                f"{reserved_vram / gib:.1f} GiB reserved memory"
            )
+        elif self.device.type == "mps":
+            torch.mps.empty_cache()
 
    def _choose_blocks(self) -> List[int]:
        if self.strict_block_indices is not None: