@@ -4,11 +4,12 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
 
+import platform
+
 import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from cpufeature import CPUFeature
 from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
@@ -29,9 +30,15 @@ class LMHead(nn.Module):
 
         self.use_chunked_forward = config.use_chunked_forward
         if self.use_chunked_forward == "auto":
-            # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
-            # Otherwise, it's ~8x slower.
-            self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            if platform.machine() == "x86_64":
+                # Import of cpufeature may crash on non-x86_64 machines
+                from cpufeature import CPUFeature
+
+                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
+                # Otherwise, it's ~8x slower.
+                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            else:
+                self.use_chunked_forward = True
         self.chunked_forward_step = config.chunked_forward_step
         self._bf16_warning_shown = False
 
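For reference, a minimal standalone sketch of the detection logic this diff introduces. The helper name `resolve_use_chunked_forward` is hypothetical (not part of the patch), and it assumes the `cpufeature` package is installed on x86_64 hosts:

```python
import platform


def resolve_use_chunked_forward(use_chunked_forward="auto"):
    """Hypothetical helper mirroring the logic added in the diff above."""
    if use_chunked_forward != "auto":
        # An explicit True/False from the config is kept as-is.
        return use_chunked_forward
    if platform.machine() == "x86_64":
        # cpufeature is only imported on x86_64; importing it on other
        # architectures may crash, which is what the patch works around.
        from cpufeature import CPUFeature

        # With full AVX512 support (CPU flag plus OS support), plain
        # bfloat16 is fast enough that chunked_forward() is not needed.
        return not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
    # On non-x86_64 machines (e.g. aarch64), always use chunked_forward().
    return True
```

On an AVX512-capable x86_64 machine this returns False (use plain bfloat16); on a non-x86_64 host it returns True without ever importing `cpufeature`, so the import can no longer crash there.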