dropout.py

import torch.autograd
from torch import nn as nn

from hivemind.moe.server.layers.custom_experts import register_expert_class


class DeterministicDropoutFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, keep_prob, mask):
        ctx.keep_prob = keep_prob
        ctx.save_for_backward(mask)
        return x * mask / keep_prob

    @staticmethod
    def backward(ctx, grad_output):
        # reuse the mask saved at forward time; keep_prob and mask receive no gradient
        return ctx.saved_tensors[0] * grad_output / ctx.keep_prob, None, None


class DeterministicDropout(nn.Module):
    """
    Custom dropout layer which accepts the dropout mask as an input (drop_prob is only used for scaling input activations).
    Can be used with RemoteExpert/ExpertBackend to ensure that the dropout mask is the same at the forward and backward steps.
    """

    def __init__(self, drop_prob):
        super().__init__()
        self.keep_prob = 1 - drop_prob

    def forward(self, x, mask):
        if self.training:
            return DeterministicDropoutFunction.apply(x, self.keep_prob, mask)
        else:
            return x
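
# Usage sketch (an illustrative addition, not part of the original hivemind file): the caller
# samples the mask once, typically from a Bernoulli distribution with p = keep_prob, and passes
# that same mask to every forward/backward call so both directions drop the same positions, e.g.
#
#   layer = DeterministicDropout(drop_prob=0.2)
#   layer.train()
#   mask = torch.bernoulli(torch.full_like(x, 0.8))  # keep ~80% of activations
#   y = layer(x, mask)                               # y = x * mask / keep_prob
#   y.sum().backward()                               # grad_x = grad_y * mask / keep_prob
#
# In eval mode the mask argument is ignored and x is returned unchanged.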


# Sample-input factory registered alongside the expert class; hivemind uses it to build example
# inputs for the expert: a float activation tensor and an integer mask of the same shape
# (torch.randint(0, 1, ...) yields all zeros since the upper bound is exclusive, a placeholder mask).
dropout_sample_input = lambda batch_size, hid_dim: \
    (torch.empty((batch_size, hid_dim)), torch.randint(0, 1, (batch_size, hid_dim)))


@register_expert_class('det_dropout', dropout_sample_input)
class DeterministicDropoutNetwork(nn.Module):
    def __init__(self, hid_dim, dropout_prob=0.2):
        super().__init__()
        self.linear_in = nn.Linear(hid_dim, 2 * hid_dim)
        self.activation = nn.ReLU()
        self.dropout = DeterministicDropout(dropout_prob)
        self.linear_out = nn.Linear(2 * hid_dim, hid_dim)

    def forward(self, x, mask):
        # dropout (with its externally supplied mask) is applied to the input before the first linear layer
        x = self.linear_in(self.dropout(x, mask))
        return self.linear_out(self.activation(x))
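

# Minimal local smoke test (an added sketch under assumed shapes, not part of the original file):
# exercises the expert end-to-end with an explicit keep-mask and checks that gradients vanish
# exactly where the mask is zero. On a hivemind MoE server, the class above is available under
# the registered name 'det_dropout'.
if __name__ == "__main__":
    batch_size, hid_dim = 4, 16
    expert = DeterministicDropoutNetwork(hid_dim, dropout_prob=0.2)
    x = torch.randn(batch_size, hid_dim, requires_grad=True)
    # the mask matches the input shape because dropout is applied before linear_in
    mask = torch.bernoulli(torch.full((batch_size, hid_dim), 0.8))
    out = expert(x, mask)  # shape: (batch_size, hid_dim)
    out.sum().backward()
    assert torch.all(x.grad[mask == 0] == 0)  # dropped positions receive no gradient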