Make distributed execution work (#33)
* Make distributed execution work

* fix tips.py

* fixes; incorporate changes from #32
franklsf95 committed Mar 23, 2023
1 parent 8ce17e2 commit 9fcf28e
Showing 12 changed files with 183 additions and 169 deletions.
29 changes: 16 additions & 13 deletions examples/tips.py
@@ -1,19 +1,22 @@
import os

import ray
from raysql.context import RaySqlContext
from raysql.worker import Worker

# Start our cluster
ray.init()
from raysql import RaySqlContext, ResultSet

# create some remote Workers
workers = [Worker.remote() for i in range(2)]
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))

# create a remote context and register a table
ctx = RaySqlContext.remote(workers)
ray.get(ctx.register_csv.remote('tips', 'tips.csv', True))
# Start a local cluster
ray.init(resources={"worker": 1})

# Parquet is also supported
# ctx.register_parquet('tips', 'tips.parquet')
# Create a context and register a table
ctx = RaySqlContext(2, use_ray_shuffle=True)
# Register either a CSV or Parquet file
# ctx.register_csv("tips", f"{SCRIPT_DIR}/tips.csv", True)
ctx.register_parquet("tips", f"{SCRIPT_DIR}/tips.parquet")

result_set = ray.get(ctx.sql.remote('select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker'))
print(result_set)
result_set = ctx.sql(
"select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker"
)
print("Result:")
print(ResultSet(result_set))
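
For context (not part of this commit), here is a minimal variant of the example that attaches to an already-running Ray cluster and registers the CSV file instead of the Parquet one. The address="auto" call and the assumption that nodes were started with a "worker" custom resource are illustrative, not something the diff prescribes.

import os

import ray

from raysql import RaySqlContext, ResultSet

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))

# Attach to an existing cluster instead of starting a local one;
# assumes the nodes already advertise a "worker" custom resource.
ray.init(address="auto")

ctx = RaySqlContext(2, use_ray_shuffle=True)
ctx.register_csv("tips", f"{SCRIPT_DIR}/tips.csv", True)

result_set = ctx.sql("select count(*) from tips")
print(ResultSet(result_set))
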
11 changes: 2 additions & 9 deletions raysql/__init__.py
@@ -1,7 +1,3 @@
from abc import ABCMeta, abstractmethod
from typing import List


try:
import importlib.metadata as importlib_metadata
except ImportError:
@@ -11,13 +7,10 @@
Context,
QueryStage,
ResultSet,
execute_partition,
serialize_execution_plan,
deserialize_execution_plan,
)
from .context import RaySqlContext

__version__ = importlib_metadata.version(__name__)

__all__ = [
"Context",
"Worker",
]
73 changes: 53 additions & 20 deletions raysql/context.py
@@ -1,14 +1,14 @@
import ray

from raysql import Context, QueryStage, ResultSet, serialize_execution_plan
from raysql.worker import Worker
import raysql
from raysql import Context, QueryStage, ResultSet, ray_utils


@ray.remote
def execute_query_stage(
query_stages: list[QueryStage],
stage_id: int,
workers: list[Worker],
num_workers: int,
use_ray_shuffle: bool,
) -> tuple[int, list[ray.ObjectRef]]:
"""
@@ -22,7 +22,9 @@ def execute_query_stage(
child_futures = []
for child_id in stage.get_child_stage_ids():
child_futures.append(
execute_query_stage.remote(query_stages, child_id, workers, use_ray_shuffle)
execute_query_stage.options(**ray_utils.current_node_aff()).remote(
query_stages, child_id, num_workers, use_ray_shuffle
)
)

# if the query stage has a single output partition then we need to execute for the output
@@ -65,28 +67,55 @@ def _get_worker_inputs(part: int) -> list[tuple[int, int, int, ray.ObjectRef]]:
if not use_ray_shuffle:
ray.get([f for _, lst in child_outputs for f in lst])

# round-robin allocation across workers
plan_bytes = serialize_execution_plan(stage.get_execution_plan())
# schedule the actual execution workers
plan_bytes = raysql.serialize_execution_plan(stage.get_execution_plan())
futures = []
opt = {}
opt["resources"] = {"worker": 1e-3}
if use_ray_shuffle:
opt["num_returns"] = output_partitions_count
for part in range(concurrency):
worker_index = part % len(workers)
opt = {}
if use_ray_shuffle:
opt["num_returns"] = output_partitions_count
futures.append(
workers[worker_index]
.execute_query_partition.options(**opt)
.remote(plan_bytes, part, _get_worker_inputs(part))
execute_query_partition.options(**opt).remote(
plan_bytes, part, _get_worker_inputs(part)
)
)

return stage_id, futures


@ray.remote
def execute_query_partition(
plan_bytes: bytes,
part: int,
input_partition_refs: list[tuple[int, int, int, ray.ObjectRef]],
) -> list[bytes]:
plan = raysql.deserialize_execution_plan(plan_bytes)
print(
"Worker executing plan {} partition #{} with shuffle inputs {}".format(
plan.display(),
part,
[(s, i, j) for s, i, j, _ in input_partition_refs],
)
)

input_data = ray.get([f for _, _, _, f in input_partition_refs])
input_partitions = [
(s, j, d) for (s, _, j, _), d in zip(input_partition_refs, input_data)
]
# This is delegating to DataFusion for execution, but this would be a good place
# to plug in other execution engines by translating the plan into another engine's plan
# (perhaps via Substrait, once DataFusion supports converting a physical plan to Substrait)
result_set = raysql.execute_partition(plan, part, input_partitions)

ret = result_set.tobyteslist()
return ret[0] if len(ret) == 1 else ret


class RaySqlContext:
def __init__(self, workers: list[Worker], use_ray_shuffle: bool = False):
self.ctx = Context(len(workers), use_ray_shuffle)
self.workers = workers
def __init__(self, num_workers: int = 1, use_ray_shuffle: bool = False):
self.ctx = Context(num_workers, use_ray_shuffle)
self.num_workers = num_workers
self.use_ray_shuffle = use_ray_shuffle

def register_csv(self, table_name: str, path: str, has_header: bool):
@@ -101,14 +130,18 @@ def sql(self, sql: str) -> ResultSet:

# serialize the query stages and store in Ray object store
query_stages = [
serialize_execution_plan(graph.get_query_stage(i).get_execution_plan())
raysql.serialize_execution_plan(
graph.get_query_stage(i).get_execution_plan()
)
for i in range(final_stage_id + 1)
]

# schedule execution
future = execute_query_stage.remote(
query_stages, final_stage_id, self.workers, self.use_ray_shuffle
future = execute_query_stage.options(**ray_utils.current_node_aff()).remote(
query_stages, final_stage_id, self.num_workers, self.use_ray_shuffle
)
_, partitions = ray.get(future)
result_set = ray.get(partitions)
# final stage should have a concurrency of 1
assert len(partitions) == 1, partitions
result_set = ray.get(partitions[0])
return result_set
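
The Ray-native shuffle path above hinges on a task returning one ObjectRef per output partition, which is what opt["num_returns"] = output_partitions_count requests. A stripped-down sketch of that Ray mechanism, with hypothetical names unrelated to the raysql plan format (not part of this commit):

import ray

ray.init()

@ray.remote
def produce_partitions(num_parts: int):
    # Return one value per output partition; because the caller sets
    # num_returns, Ray hands back a separate ObjectRef for each value.
    return tuple(f"partition-{i}" for i in range(num_parts))

# Four ObjectRefs, one per partition, which downstream tasks can fetch
# individually without materializing the rest.
refs = produce_partitions.options(num_returns=4).remote(4)
print(ray.get(refs[2]))  # -> "partition-2"

In the diff, those per-partition refs are what _get_worker_inputs packages up (together with stage and partition ids) for the next stage's execute_query_partition tasks when use_ray_shuffle is enabled.
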
71 changes: 40 additions & 31 deletions raysql/main.py
@@ -1,24 +1,20 @@
import time
import os

import ray
from raysql import ResultSet
from raysql.context import RaySqlContext
from raysql.worker import Worker
from raysql import RaySqlContext, ResultSet

DATA_DIR = "/home/ubuntu/tpch/sf1-parquet"
# DATA_DIR = "/home/ubuntu/sf10-parquet"
NUM_CPUS_PER_WORKER = 8

ray.init()
# ray.init(local_mode=True)
SF = 10
DATA_DIR = f"/mnt/data0/tpch/sf{SF}-parquet"
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
QUERIES_DIR = os.path.join(SCRIPT_DIR, f"../sqlbench-h/queries/sf={SF}")


def setup_context(use_ray_shuffle: bool) -> RaySqlContext:
num_workers = 2
# num_workers = os.cpu_count()
workers = [Worker.remote() for _ in range(num_workers)]
ctx = RaySqlContext.remote(workers, use_ray_shuffle)
register_tasks = []
register_tasks.append(ctx.register_csv.remote("tips", "examples/tips.csv", True))
def setup_context(use_ray_shuffle: bool, num_workers: int = 2) -> RaySqlContext:
print(f"Using {num_workers} workers")
ctx = RaySqlContext(num_workers, use_ray_shuffle)
for table in [
"customer",
"lineitem",
@@ -29,29 +25,37 @@ def setup_context(use_ray_shuffle: bool) -> RaySqlContext:
"region",
"supplier",
]:
register_tasks.append(
ctx.register_parquet.remote(table, f"{DATA_DIR}/{table}.parquet")
)
ctx.register_parquet(table, f"{DATA_DIR}/{table}.parquet")
return ctx


def load_query(n: int) -> str:
with open(f"testdata/queries/q{n}.sql") as fin:
with open(f"{QUERIES_DIR}/q{n}.sql") as fin:
return fin.read()


def tpchq(ctx: RaySqlContext, q: int = 14):
def tpch_query(ctx: RaySqlContext, q: int = 1):
sql = load_query(q)
result_set = ray.get(ctx.sql.remote(sql))
result_set = ctx.sql(sql)
return result_set


def tpch_timing(ctx: RaySqlContext, q: int = 1, print_result: bool = False):
sql = load_query(q)
start = time.perf_counter()
result = ctx.sql(sql)
if print_result:
print(ResultSet(result))
end = time.perf_counter()
return end - start


def compare(q: int):
ctx = setup_context(False)
result_set_truth = tpchq(ctx, q)
result_set_truth = tpch_query(ctx, q)

ctx = setup_context(True)
result_set_ray = tpchq(ctx, q)
result_set_ray = tpch_query(ctx, q)

assert result_set_truth == result_set_ray, (
q,
@@ -60,13 +64,18 @@ def compare(q: int):
)


# use_ray_shuffle = True
# ctx = setup_context(use_ray_shuffle)
# result_set = tpchq(ctx, 1)
# print("Result:")
# print(ResultSet(result_set))
def tpch_bench():
ray.init("auto")
num_workers = int(ray.cluster_resources().get("worker", 1)) * NUM_CPUS_PER_WORKER
ctx = setup_context(True, num_workers)
run_id = time.strftime("%Y-%m-%d-%H-%M-%S")
with open(f"results-sf{SF}-{run_id}.csv", "w") as fout:
for i in range(1, 22 + 1):
if i == 15:
continue
result = tpch_timing(ctx, i)
print(f"query,{i},{result}")
print(f"query,{i},{result}", file=fout, flush=True)


for i in range(1, 22 + 1):
if i == 15:
continue
compare(i)
tpch_bench()
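
tpch_bench() assumes an already-running cluster (ray.init("auto")), nodes that advertise a "worker" custom resource, and TPC-H Parquet data plus sqlbench-h query files at the hard-coded paths. A minimal single-node sketch for timing one ad-hoc query without those prerequisites, with a hypothetical data path (not part of this commit):

import time

import ray

from raysql import RaySqlContext, ResultSet

ray.init(resources={"worker": 1})  # a single local node standing in for the cluster

ctx = RaySqlContext(8, use_ray_shuffle=True)
ctx.register_parquet("lineitem", "/path/to/sf1-parquet/lineitem.parquet")  # hypothetical path

start = time.perf_counter()
result = ctx.sql("select count(*) from lineitem")
print(ResultSet(result))
print(f"elapsed: {time.perf_counter() - start:.2f}s")
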
14 changes: 14 additions & 0 deletions raysql/ray_utils.py
@@ -0,0 +1,14 @@
import ray


def node_aff(node_id: ray.NodeID, *, soft: bool = False) -> dict:
return {
"scheduling_strategy": ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=soft,
)
}


def current_node_aff() -> dict:
return node_aff(ray.get_runtime_context().get_node_id())
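
These helpers only build the scheduling_strategy keyword argument that context.py splats into .options(...). A small sketch of using them directly, assuming a Ray 2.3 runtime (not part of this commit):

import ray

from raysql import ray_utils

ray.init()

@ray.remote
def where_am_i() -> str:
    # Report which node the task landed on.
    return ray.get_runtime_context().get_node_id()

# Pin a task to the driver's node, as execute_query_stage does in context.py.
print(ray.get(where_am_i.options(**ray_utils.current_node_aff()).remote()))

# Or target a specific node, with soft affinity so the task may still run
# elsewhere if that node cannot accept it.
node_id = ray.nodes()[0]["NodeID"]
print(ray.get(where_am_i.options(**ray_utils.node_aff(node_id, soft=True)).remote()))
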
39 changes: 0 additions & 39 deletions raysql/worker.py

This file was deleted.

3 changes: 2 additions & 1 deletion requirements-in.txt
@@ -1,10 +1,11 @@
black
flake8
isort
maturin
maturin[patchelf]
mypy
numpy
pyarrow
pytest
ray==2.3.0
toml
importlib_metadata; python_version < "3.8"