فهرست منبع

Update benchmarks/benchmark_optimizer.py

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic 3 سال پیش
والد
کامیت
bd93881fe4
1فایلهای تغییر یافته به همراه1 افزوده شده و 1 حذف شده
  1. 1 1
      benchmarks/benchmark_optimizer.py

+ 1 - 1
benchmarks/benchmark_optimizer.py

@@ -64,7 +64,7 @@ def benchmark_optimizer(args: TrainingArguments):
     dht = hivemind.DHT(start=True)
 
     train_dataset = args.make_dataset()
-    num_features = np.prod(train_dataset.data[0].shape)
+    num_features = train_dataset.data[0].numel()
     num_classes = len(train_dataset.classes)
     X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
     X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))