Explorar o código

Update benchmarks/benchmark_optimizer.py

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic %!s(int64=3) %!d(string=hai) anos
pai
achega
bd93881fe4
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      benchmarks/benchmark_optimizer.py

+ 1 - 1
benchmarks/benchmark_optimizer.py

@@ -64,7 +64,7 @@ def benchmark_optimizer(args: TrainingArguments):
     dht = hivemind.DHT(start=True)
 
     train_dataset = args.make_dataset()
-    num_features = np.prod(train_dataset.data[0].shape)
+    num_features = train_dataset.data[0].numel()
     num_classes = len(train_dataset.classes)
     X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
     X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))