Как было описано в лекции, одно из ключевых нововведений в Auto-scheduler -- появление правил вывода для генерации новых реализаций -- эскизов. Работа с эскизами была перенесена в MetaScheduler, на примере которого можно рассмотреть работу этого подхода. Рассмотрим работу на примере умножения матриц.
import tvm
from tvm import te, topi, tir
from tvm.ir.module import IRModule
from tvm import meta_schedule as ms
Определим матричное умножение аналогично тому, как делали это ранее. Далее создадим тензорное выражение, которое преобразуем в IRModule. Из тензорного выражения получим TIR, который можно передать в модуль Schedule и начать оптимизировать оператор.
def matmul(N, L, M, dtype):
A = te.placeholder((N, L), name="A", dtype=dtype)
B = te.placeholder((L, M), name="B", dtype=dtype)
k = te.reduce_axis((0, L), name="k")
C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
s = te.create_schedule(C.op)
return s, [A, B, C]
N, L, M = 512, 512, 512
s, (A, B, C) = matmul(N, L, M, 'float32')
func = te.create_prim_func([A, B, C])
tir_matmul = IRModule({"main": func})
print(tir_matmul.script())
# from tvm.script import ir as I # from tvm.script import tir as T @I.ir_module class Module: @T.prim_func def main(A: T.Buffer((512, 512), "float32"), B: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): for i, j, k in T.grid(512, 512, 512): with T.block("C"): v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) T.reads(A[v_i, v_k], B[v_k, v_j]) T.writes(C[v_i, v_j]) with T.init(): C[v_i, v_j] = T.float32(0) C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
Применим правила генерации блочного кода для трех циклов умножения матриц. Новая реализация будет содержать 10 циклов. Также MetaScheduler сразу произведет случайное аннотирование размера циклов.
get_block('C')
возвращает блок кода, внутри которого будет производиться оптимизация.MultiLevelTiling('SSRSRS')
-- применение правила для разбиения кода и аннотирования циклов.sch = tir.Schedule(func)
C = sch.get_block('C')
mlt_rule = ms.schedule_rule.MultiLevelTiling('SSRSRS')
sch = mlt_rule.apply(sch, C)[0]
sch.mod['main'].show()
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((512, 512), "float32"), B: T.Buffer((512, 512), "float32"), C: T.Buffer((512, 512), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 in T.grid(4, 2, 16, 1, 4, 8, 4, 128, 1, 64):
with T.block("C"):
v_i = T.axis.spatial(512, i_0 * 128 + i_1 * 8 + i_2 + i_3)
v_j = T.axis.spatial(512, j_0 * 256 + j_1 * 256 + j_2 * 64 + j_3)
v_k = T.axis.reduce(512, k_0 * 128 + k_1)
T.reads(A[v_i, v_k], B[v_k, v_j])
T.writes(C[v_i, v_j])
T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
with T.init():
C[v_i, v_j] = T.float32(0)
C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]