|
@@ -107,7 +107,7 @@ class CustomMatmulLtState(MatmulLtState):
|
|
|
"col_turing",
|
|
|
"col_ampere",
|
|
|
), f"please find this assert and manually enter tile size for {self.formatB}"
|
|
|
- return (8, 32) if self.formatB == "col_turing" else "col_ampere"
|
|
|
+ return (8, 32) if self.formatB == "col_turing" else (32, 32)
|
|
|
|
|
|
|
|
|
def custom_matmul8bitlt(
|