@@ -12,7 +12,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
-from datasets import Dataset
+from torch.utils.data import Dataset
import hivemind
from hivemind.averaging.control import AveragingStage