|
@@ -64,7 +64,7 @@ def benchmark_optimizer(args: TrainingArguments):
|
|
|
dht = hivemind.DHT(start=True)
|
|
|
|
|
|
train_dataset = args.make_dataset()
|
|
|
- num_features = np.prod(train_dataset.data[0].shape)
|
|
|
+ num_features = train_dataset.data[0].numel()
|
|
|
num_classes = len(train_dataset.classes)
|
|
|
X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
|
|
|
X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
|