Make batch matrix multiplication on GPU tunable #5752

Merged: 1 commit, Jun 11, 2020
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -463,7 +463,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
"""batch_matmul cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.nn.batch_matmul),
wrap_compute_batch_matmul(topi.cuda.batch_matmul),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10)
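Registering the compute and schedule under the task name "batch_matmul.cuda" is what makes the op tunable. Below is a minimal sketch of how the resulting task could be tuned, assuming the AutoTVM API of this TVM version; the shapes, trial count, and log file name are illustrative, not from the PR:

```python
# Hypothetical tuning sketch (not part of this PR).
import tvm
from tvm import te, autotvm

# batch_matmul takes x: (batch, M, K) and y: (batch, N, K).
x = te.placeholder((8, 128, 256), name="x", dtype="float32")
y = te.placeholder((8, 64, 256), name="y", dtype="float32")

# The task name matches the register_topi_compute/schedule decorators below.
task = autotvm.task.create("batch_matmul.cuda", args=(x, y), target="cuda")

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=10))

tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=100,
           measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file("batch_matmul.cuda.log")])
```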
80 changes: 56 additions & 24 deletions topi/python/topi/cuda/batch_matmul.py
@@ -14,13 +14,24 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,too-many-locals,unused-variable
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
 """cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
 from tvm import te
 from tvm.contrib import cublas
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
+from .. import nn
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

-def schedule_batch_matmul(outs):
+@autotvm.register_topi_compute("batch_matmul.cuda")
+def batch_matmul(cfg, x, y):
+    """Compute batch matrix multiplication of x and y"""
+    return nn.batch_matmul(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul.cuda")
+def schedule_batch_matmul(cfg, outs):
     """Schedule for batch_matmul

     Parameters
@@ -37,7 +48,7 @@ def schedule_batch_matmul(outs):
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])

-    def _schedule(op):
+    def _schedule(cfg, op):
         C = op.output(0)
         A, B = s[C].op.input_tensors
         _, M, N = get_const_tuple(C.shape)
@@ -51,55 +62,76 @@ def _schedule(op):
         C = s.outputs[0].output(0)

         b, y, x = s[C].op.axis
-        y_bn = get_max_power2_factor(M, 64)
-        x_bn = get_max_power2_factor(N, 64)
-        by, y = s[C].split(y, y_bn)
-        bx, x = s[C].split(x, x_bn)
-        y_nthreads = min(y_bn, 8)
-        x_nthreads = min(x_bn, 8)
-        ty, yi = s[C].split(y, nparts=y_nthreads)
-        tx, xi = s[C].split(x, nparts=x_nthreads)
-        thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
-        thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
+        k, = s[CC].op.reduce_axis
+
+        cfg.define_split("tile_y", y, num_outputs=3)
+        cfg.define_split("tile_x", x, num_outputs=3)
+        cfg.define_split("tile_k", k, num_outputs=2)
+        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
+        target = tvm.target.Target.current()
+        if target.target_name in ['nvptx', 'rocm']:
+            # llvm-based backends cannot do non-explicit unrolling
+            cfg.define_knob("unroll_explicit", [1])
+        else:
+            cfg.define_knob("unroll_explicit", [0, 1])
+
+        if cfg.is_fallback:
Contributor: Missing unroll_explicit in the fallback schedule. Based on your comment, this might cause a problem on LLVM-based backends.

Contributor Author: Good catch, thank you!

Contributor Author: Upon looking at this, I should use unroll_explicit, not just define it. What I'm unsure about is whether I need to define it in the fallback: conv2d_direct doesn't define it and seems to work well.

Contributor: Hmm, it's interesting. It means this parameter is never effective in those schedules... Others like conv2d_direct use it like:

s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

I think you can also fix them if you prefer to.

Contributor Author: I think it should be OK now.

Contributor: Sorry, I just found that unroll_explicit is still missing in the fallback. Also, would you mind moving L103 (define tile_k) up together with the other tuning parameters?

Contributor Author (t-vi, Jun 9, 2020): As much fun as geeking out over this is, I don't think I need the fallback, because:

  • this schedule works without it,
  • other schedules work without it,
  • define_knob sets the fallback to the first option by setting cfg._entity_map, which is what __getitem__ queries.

I cannot move the definition of tile_k to before k is defined.
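The third point can be checked directly. A minimal sketch (hypothetical, assuming the AutoTVM ConfigSpace API of this era) showing that the first knob option serves as the implicit fallback:

```python
from tvm.autotvm.task.space import ConfigSpace

cfg = ConfigSpace()
cfg.define_knob("unroll_explicit", [0, 1])
# Without any tuning, __getitem__ reads cfg._entity_map, which define_knob
# seeded with the first option, so the untuned fallback value is 0 here.
print(cfg["unroll_explicit"].val)  # -> 0
```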

Contributor: If the first two points matter, we should just remove that parameter to make the tuning space more efficient. Meanwhile, I appreciate the third point, which I hadn't realized before.

For tile_k, I think you should be able to move the definition of k up as well. It should be safe, because tuning parameters must be static and so won't depend on other parameters. In this way, we can also put the fallback configs together to make it clearer.

Contributor Author: I moved k and tile_k; it appears to work, and it looks much prettier indeed. Thank you for insisting on it.

Contributor Author (t-vi, Jun 10, 2020): What else can I do to make it prettier?

+            y_bn = get_max_power2_factor(M, 64)
+            x_bn = get_max_power2_factor(N, 64)
+            y_nthreads = min(y_bn, 8)
+            x_nthreads = min(x_bn, 8)
+            cfg['tile_x'] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads])
+            cfg['tile_y'] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads])
+            cfg['tile_k'] = SplitEntity([-1, 8])
+            cfg['auto_unroll_max_step'] = OtherOptionEntity(16)

+        by, ty, yi = cfg["tile_y"].apply(s, C, y)
+        bx, tx, xi = cfg["tile_x"].apply(s, C, x)
+
+        thread_x = te.thread_axis("threadIdx.x")
+        thread_y = te.thread_axis("threadIdx.y")

         s[C].reorder(b, by, bx, ty, tx, yi, xi)
         s[C].bind(b, te.thread_axis("blockIdx.z"))
         s[C].bind(by, te.thread_axis("blockIdx.y"))
         s[C].bind(bx, te.thread_axis("blockIdx.x"))
         s[C].bind(ty, thread_y)
         s[C].bind(tx, thread_x)
-        s[C].pragma(yi, "auto_unroll_max_step", 16)
+        s[C].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
+        s[C].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val)

         s[CC].compute_at(s[C], tx)
         _, yi, xi = s[CC].op.axis
-        k, = s[CC].op.reduce_axis
-        ko, ki = s[CC].split(k, 8)
+        ko, ki = cfg["tile_k"].apply(s, CC, k)
         s[CC].reorder(ko, ki, yi, xi)
-        s[CC].pragma(ki, "auto_unroll_max_step", 16)
+        s[CC].pragma(ki, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
+        s[CC].pragma(ki, 'unroll_explicit', cfg['unroll_explicit'].val)

         s[AA].compute_at(s[CC], ko)
         s[AL].compute_at(s[CC], ki)
         s[BB].compute_at(s[CC], ko)
         s[BL].compute_at(s[CC], ki)
         _, y, k = s[AA].op.axis
-        ty, yi = s[AA].split(y, nparts=y_nthreads)
-        tx, ki = s[AA].split(k, nparts=x_nthreads)
+        ty, yi = s[AA].split(y, nparts=cfg["tile_y"].size[1])
+        tx, ki = s[AA].split(k, nparts=cfg["tile_x"].size[1])
         s[AA].reorder(ty, tx, yi, ki)
         s[AA].bind(ty, thread_y)
         s[AA].bind(tx, thread_x)
-        s[AA].pragma(yi, "auto_unroll_max_step", 16)
+        s[AA].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
+        s[AA].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val)

         _, x, k = s[BB].op.axis
-        ty, xi = s[BB].split(x, nparts=y_nthreads)
-        tx, ki = s[BB].split(k, nparts=x_nthreads)
+        ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1])
+        tx, ki = s[BB].split(k, nparts=cfg["tile_x"].size[1])
         s[BB].bind(ty, thread_y)
         s[BB].bind(tx, thread_x)
         s[BB].reorder(ty, tx, xi, ki)
-        s[BB].pragma(xi, "auto_unroll_max_step", 16)
+        s[BB].pragma(xi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
+        s[BB].pragma(xi, 'unroll_explicit', cfg['unroll_explicit'].val)

     def _callback(op):
         if "batch_matmul" in op.tag:
-            _schedule(op)
+            _schedule(cfg, op)

     traverse_inline(s, outs[0].op, _callback)
     return s
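For readers unfamiliar with the fallback entities used above: SplitEntity([-1, a, b]) factors an axis into three parts, with -1 inferred from the axis extent. A standalone sketch of the equivalent te splits (illustrative only; the elementwise compute here is a stand-in, not the batch_matmul op):

```python
import tvm
from tvm import te

# tile_y = SplitEntity([-1, 8, 8]) applied to an axis of extent M is
# equivalent to two factor splits: innermost factor 8 (yi), then thread
# factor 8 (ty); the outer extent by = M // 64 is the inferred -1 part.
M = 512
A = te.placeholder((M, M), name="A")
C = te.compute((M, M), lambda i, j: A[i, j] * 2.0, name="C")
s = te.create_schedule(C.op)
y, x = s[C].op.axis
rest, yi = s[C].split(y, factor=8)   # innermost factor, size[2]
by, ty = s[C].split(rest, factor=8)  # thread factor, size[1]; by is M // 64
```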
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_batch_matmul.py
@@ -28,7 +28,7 @@
 _batch_matmul_implement = {
     "generic": (topi.nn.batch_matmul, topi.generic.schedule_batch_matmul),
     "cpu": (topi.x86.batch_matmul, topi.x86.schedule_batch_matmul),
-    "gpu": (topi.nn.batch_matmul, topi.cuda.schedule_batch_matmul),
+    "gpu": (topi.cuda.batch_matmul, topi.cuda.schedule_batch_matmul),
 }

 def verify_batch_matmul(batch, M, N, K):
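For context, the implement table above is consumed by the usual topi device-dispatch pattern. A hedged sketch of how the "gpu" entry would be exercised end to end, assuming the TVM API of this era (the helper below and its tolerance are illustrative, not from the PR):

```python
# Hypothetical end-to-end check mirroring the test file's structure.
import numpy as np
import tvm
from tvm import te

def check_gpu(batch, M, N, K):
    x = te.placeholder((batch, M, K), name="x")
    y = te.placeholder((batch, N, K), name="y")
    ctx = tvm.gpu(0)
    # Inside a target context, the AutoTVM-registered compute/schedule pick
    # up the current config (the fallback one unless a tuning log is applied).
    with tvm.target.create("cuda"):
        fcompute, fschedule = _batch_matmul_implement["gpu"]
        out = fcompute(x, y)
        s = fschedule([out])
    f = tvm.build(s, [x, y, out], "cuda")
    a = tvm.nd.array(np.random.uniform(size=(batch, M, K)).astype("float32"), ctx)
    b = tvm.nd.array(np.random.uniform(size=(batch, N, K)).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros((batch, M, N), dtype="float32"), ctx)
    f(a, b, c)
    # batch_matmul treats y as transposed: out = x @ y^T per batch.
    ref = np.matmul(a.asnumpy(), b.asnumpy().transpose(0, 2, 1))
    tvm.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-5)
```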