|
@@ -12,7 +12,7 @@ import torch
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
import torchvision
|
|
-from datasets import Dataset
|
|
|
|
|
|
+from torch.utils.data import Dataset
|
|
|
|
|
|
import hivemind
|
|
import hivemind
|
|
from hivemind.averaging.control import AveragingStage
|
|
from hivemind.averaging.control import AveragingStage
|