|
@@ -196,10 +196,6 @@ class BloomScaledSoftmax(nn.Module):
|
|
fused operation: scaling + mask + softmax
|
|
fused operation: scaling + mask + softmax
|
|
|
|
|
|
Args:
|
|
Args:
|
|
- input_in_fp16 (`bool`, *required*):
|
|
|
|
- flag to indicate if input in fp16 data format.
|
|
|
|
- input_in_bf16 (`bool`, *required*):
|
|
|
|
- flag to indicate if input in bf16 data format.
|
|
|
|
scaled_masked_softmax_fusion (`bool`, *required*):
|
|
scaled_masked_softmax_fusion (`bool`, *required*):
|
|
flag to indicate whether the user wants to use softmax fusion
|
|
flag to indicate whether the user wants to use softmax fusion
|
|
mask_func (`function`, *required*):
|
|
mask_func (`function`, *required*):
|