Skip to content

Commit

Permalink
Make batch matrix multiplication on GPU tunable
Browse files Browse the repository at this point in the history
This is primarily aimed at the AMD GPU backend and done as part
of a project for AMD, but should work for all users of the GPU
schedule.
  • Loading branch information
t-vi committed Jun 9, 2020
1 parent 6ae439c commit 8999bea
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 26 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,9 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
"""batch_matmul cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.nn.batch_matmul),
wrap_compute_batch_matmul(topi.cuda.batch_matmul),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
name="batch_matmul.gpu",
plevel=10)
if target.target_name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
Expand Down
75 changes: 52 additions & 23 deletions topi/python/topi/cuda/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals,unused-variable
# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
"""cuda batch_matmul operators"""
import tvm
from tvm import autotvm
from tvm import te
from tvm.contrib import cublas
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

def schedule_batch_matmul(outs):
@autotvm.register_topi_compute("batch_matmul.gpu")
def batch_matmul(cfg, x, y):
"""Compute conv2d with NCHW layout"""
return nn.batch_matmul(x, y)


@autotvm.register_topi_schedule("batch_matmul.gpu")
def schedule_batch_matmul(cfg, outs):
"""Schedule for batch_matmul
Parameters
Expand All @@ -37,7 +48,7 @@ def schedule_batch_matmul(outs):
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _schedule(op):
def _schedule(cfg, op):
C = op.output(0)
A, B = s[C].op.input_tensors
_, M, N = get_const_tuple(C.shape)
Expand All @@ -51,55 +62,73 @@ def _schedule(op):
C = s.outputs[0].output(0)

b, y, x = s[C].op.axis
y_bn = get_max_power2_factor(M, 64)
x_bn = get_max_power2_factor(N, 64)
by, y = s[C].split(y, y_bn)
bx, x = s[C].split(x, x_bn)
y_nthreads = min(y_bn, 8)
x_nthreads = min(x_bn, 8)
ty, yi = s[C].split(y, nparts=y_nthreads)
tx, xi = s[C].split(x, nparts=x_nthreads)
thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")

cfg.define_split("tile_y", y, num_outputs=3)
cfg.define_split("tile_x", x, num_outputs=3)
cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']:
# llvm-based backends cannot do non-explicit unrolling
cfg.define_knob("unroll_explicit", [1])
else:
cfg.define_knob("unroll_explicit", [0, 1])

if cfg.is_fallback:
y_bn = get_max_power2_factor(M, 64)
x_bn = get_max_power2_factor(N, 64)
y_nthreads = min(y_bn, 8)
x_nthreads = min(x_bn, 8)
cfg['tile_x'] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads])
cfg['tile_y'] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads])
cfg['auto_unroll_max_step'] = OtherOptionEntity(16)

by, ty, yi = cfg["tile_y"].apply(s, C, y)
bx, tx, xi = cfg["tile_x"].apply(s, C, x)

thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")

s[C].reorder(b, by, bx, ty, tx, yi, xi)
s[C].bind(b, te.thread_axis("blockIdx.z"))
s[C].bind(by, te.thread_axis("blockIdx.y"))
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
s[C].pragma(yi, "auto_unroll_max_step", 16)
s[C].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)

s[CC].compute_at(s[C], tx)
_, yi, xi = s[CC].op.axis
k, = s[CC].op.reduce_axis
ko, ki = s[CC].split(k, 8)
cfg.define_split("tile_k", k, num_outputs=2)
if cfg.is_fallback:
cfg['tile_k'] = SplitEntity([-1, 8])
ko, ki = cfg["tile_k"].apply(s, CC, k)
s[CC].reorder(ko, ki, yi, xi)
s[CC].pragma(ki, "auto_unroll_max_step", 16)
s[CC].pragma(ki, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)

s[AA].compute_at(s[CC], ko)
s[AL].compute_at(s[CC], ki)
s[BB].compute_at(s[CC], ko)
s[BL].compute_at(s[CC], ki)
_, y, k = s[AA].op.axis
ty, yi = s[AA].split(y, nparts=y_nthreads)
tx, ki = s[AA].split(k, nparts=x_nthreads)
ty, yi = s[AA].split(y, nparts=cfg["tile_y"].size[1])
tx, ki = s[AA].split(k, nparts=cfg["tile_x"].size[1])
s[AA].reorder(ty, tx, yi, ki)
s[AA].bind(ty, thread_y)
s[AA].bind(tx, thread_x)
s[AA].pragma(yi, "auto_unroll_max_step", 16)
s[AA].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)

_, x, k = s[BB].op.axis
ty, xi = s[BB].split(x, nparts=y_nthreads)
tx, ki = s[BB].split(k, nparts=x_nthreads)
ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1])
tx, ki = s[BB].split(k, nparts=cfg["tile_x"].size[1])
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
s[BB].reorder(ty, tx, xi, ki)
s[BB].pragma(xi, "auto_unroll_max_step", 16)
s[BB].pragma(xi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)

def _callback(op):
if "batch_matmul" in op.tag:
_schedule(op)
_schedule(cfg, op)

traverse_inline(s, outs[0].op, _callback)
return s
Expand Down
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_batch_matmul_implement = {
"generic": (topi.nn.batch_matmul, topi.generic.schedule_batch_matmul),
"cpu": (topi.x86.batch_matmul, topi.x86.schedule_batch_matmul),
"gpu": (topi.nn.batch_matmul, topi.cuda.schedule_batch_matmul),
"gpu": (topi.cuda.batch_matmul, topi.cuda.schedule_batch_matmul),
}

def verify_batch_matmul(batch, M, N, K):
Expand Down

0 comments on commit 8999bea

Please sign in to comment.