
Fix use_chunked_forward="auto" on non-x86_64 machines (#267)

Importing cpufeature may crash on non-x86_64 machines, so this PR makes the client import it only when necessary (i.e., on x86_64).
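For context, here is the guarded-import pattern the diff below applies, as a minimal standalone sketch (the helper name is illustrative and not part of the PR; the `CPUFeature` keys come from the diff itself):

```python
import platform


def cpu_supports_avx512() -> bool:
    """Check whether the CPU and OS both support AVX512.

    cpufeature is imported lazily and only on x86_64, since importing it
    on other architectures may crash the process.
    """
    if platform.machine() != "x86_64":
        return False
    from cpufeature import CPUFeature  # safe here: we know we're on x86_64
    return bool(CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
```

With such a helper, `use_chunked_forward="auto"` resolves to `not cpu_supports_avx512()`: chunked forward is used whenever AVX512 is unavailable, including on all non-x86_64 machines.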
Alexander Borzunov 2 years ago
commit fd9400b392
1 file changed, 11 insertions(+), 4 deletions(-)

src/petals/bloom/modeling_utils.py (+11 −4)

@@ -4,11 +4,12 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
 
+import platform
+
 import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from cpufeature import CPUFeature
 from hivemind import get_logger
 from torch import nn
 from transformers import BloomConfig
@@ -29,9 +30,15 @@ class LMHead(nn.Module):
 
         self.use_chunked_forward = config.use_chunked_forward
         if self.use_chunked_forward == "auto":
-            # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
-            # Otherwise, it's ~8x slower.
-            self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            if platform.machine() == "x86_64":
+                # Import of cpufeature may crash on non-x86_64 machines
+                from cpufeature import CPUFeature
+
+                # If the CPU supports AVX512, plain bfloat16 is ~10x faster than chunked_forward().
+                # Otherwise, it's ~8x slower.
+                self.use_chunked_forward = not (CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"])
+            else:
+                self.use_chunked_forward = True
         self.chunked_forward_step = config.chunked_forward_step
         self._bf16_warning_shown = False