Use PyArrow for zero-copy interaction with the Ray Object Store (#36)
* Optimize Ray shuffle with zero-copy object store

* remove more clones

* change bytes to pyarrow.array

* revert /tmp

* remove empty_result_set

* remove empty_result_set

* Fix input partition count bug
franklsf95 committed Apr 5, 2023
1 parent ece0d3b commit f985808
Showing 9 changed files with 263 additions and 258 deletions.
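
The core idea of the change: instead of shipping serialized bytes between query stages, shuffle partitions now travel as pyarrow data through the Ray object store, which Ray can deserialize from shared memory without copying into each worker's heap. A minimal sketch of that pattern (not part of this commit; assumes a local Ray cluster and that zero-copy reads apply to Arrow buffers as they do to numpy arrays):

import pyarrow as pa
import ray

ray.init()

# Store a RecordBatch once in the shared-memory object store.
batch = pa.record_batch([pa.array([1, 2, 3])], names=["a"])
ref = ray.put(batch)

@ray.remote
def consume(b: pa.RecordBatch) -> int:
    # `b` is backed by the object-store buffer rather than a per-task copy.
    return b.num_rows

# Passing the ObjectRef lets Ray resolve it inside the worker.
assert ray.get(consume.remote(ref)) == 3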
4 changes: 3 additions & 1 deletion .gitignore
@@ -2,4 +2,6 @@
target
__pycache__
venv
*.so
*.so
*.log
results-sf*
3 changes: 1 addition & 2 deletions raysql/__init__.py
@@ -5,12 +5,11 @@

from ._raysql_internal import (
Context,
ExecutionGraph,
QueryStage,
ResultSet,
execute_partition,
serialize_execution_plan,
deserialize_execution_plan,
empty_result_set
)
from .context import RaySqlContext

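With empty_result_set dropped and ExecutionGraph newly exported, sql() now returns pyarrow record batches instead of a ResultSet wrapper (see the changed annotation in context.py below). A sketch of how a caller reads after this change; the constructor argument order is an assumption, since setup_context's body is not shown in this diff:

import ray
from raysql import RaySqlContext

ray.init()
ctx = RaySqlContext(2, True)  # assumed (num_workers, use_ray_shuffle)
ctx.register_parquet("lineitem", "/path/to/lineitem.parquet")  # hypothetical path
result = ctx.sql("SELECT count(*) FROM lineitem")
# main.py below shows the result may be one RecordBatch or a list of them
rows = result[0].to_pandas() if isinstance(result, list) else result.to_pandas()
print(rows)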
194 changes: 138 additions & 56 deletions raysql/context.py
@@ -1,14 +1,80 @@
import json
import os
import time
from typing import Iterable

import pyarrow as pa
import ray

import raysql
from raysql import Context, QueryStage, ResultSet, ray_utils
from raysql import Context, ExecutionGraph, QueryStage


@ray.remote
def schedule_execution(
graph: ExecutionGraph,
stage_id: int,
is_final_stage: bool,
) -> list[ray.ObjectRef]:
stage = graph.get_query_stage(stage_id)
# execute child stages first
# A list of (stage ID, list of futures) for each child stage
# Each list is a 2-D array of (input partitions, output partitions).
child_outputs = []
for child_id in stage.get_child_stage_ids():
child_outputs.append((child_id, schedule_execution(graph, child_id, False)))
# child_outputs.append((child_id, schedule_execution(graph, child_id)))

concurrency = stage.get_input_partition_count()
output_partitions_count = stage.get_output_partition_count()
if is_final_stage:
print("Forcing reduce stage concurrency from {} to 1".format(concurrency))
concurrency = 1

print(
"Scheduling query stage #{} with {} input partitions and {} output partitions".format(
stage.id(), concurrency, output_partitions_count
)
)

def _get_worker_inputs(
part: int,
) -> tuple[list[tuple[int, int, int]], list[ray.ObjectRef]]:
ids = []
futures = []
for child_stage_id, child_futures in child_outputs:
for i, lst in enumerate(child_futures):
if isinstance(lst, list):
for j, f in enumerate(lst):
if concurrency == 1 or j == part:
# If concurrency is 1, pass in all shuffle partitions. Otherwise,
# only pass in the partitions that match the current worker partition.
ids.append((child_stage_id, i, j))
futures.append(f)
elif concurrency == 1 or part == 0:
ids.append((child_stage_id, i, 0))
futures.append(lst)
return ids, futures

# schedule the actual execution workers
plan_bytes = raysql.serialize_execution_plan(stage.get_execution_plan())
futures = []
opt = {}
opt["resources"] = {"worker": 1e-3}
opt["num_returns"] = output_partitions_count
for part in range(concurrency):
ids, inputs = _get_worker_inputs(part)
futures.append(
execute_query_partition.options(**opt).remote(
stage_id, plan_bytes, part, ids, *inputs
)
)
return futures


@ray.remote(num_cpus=0)
def execute_query_stage(
query_stages: list[QueryStage],
stage_id: int,
num_workers: int,
use_ray_shuffle: bool,
) -> tuple[int, list[ray.ObjectRef]]:
"""
@@ -22,9 +88,7 @@ def execute_query_stage(
child_futures = []
for child_id in stage.get_child_stage_ids():
child_futures.append(
execute_query_stage.options(**ray_utils.current_node_aff()).remote(
query_stages, child_id, num_workers, use_ray_shuffle
)
execute_query_stage.remote(query_stages, child_id, use_ray_shuffle)
)

# if the query stage has a single output partition then we need to execute for the output
@@ -46,21 +110,25 @@ def execute_query_stage(
# Each list is a 2-D array of (input partitions, output partitions).
child_outputs = ray.get(child_futures)

def _get_worker_inputs(part: int) -> list[tuple[int, int, int, ray.ObjectRef]]:
ret = []
if not use_ray_shuffle:
return []
for child_stage_id, child_futures in child_outputs:
for i, lst in enumerate(child_futures):
if isinstance(lst, list):
for j, f in enumerate(lst):
if concurrency == 1 or j == part:
# If concurrency is 1, pass in all shuffle partitions. Otherwise,
# only pass in the partitions that match the current worker partition.
ret.append((child_stage_id, i, j, f))
elif concurrency == 1 or part == 0:
ret.append((child_stage_id, i, 0, lst))
return ret
def _get_worker_inputs(
part: int,
) -> tuple[list[tuple[int, int, int]], list[ray.ObjectRef]]:
ids = []
futures = []
if use_ray_shuffle:
for child_stage_id, child_futures in child_outputs:
for i, lst in enumerate(child_futures):
if isinstance(lst, list):
for j, f in enumerate(lst):
if concurrency == 1 or j == part:
# If concurrency is 1, pass in all shuffle partitions. Otherwise,
# only pass in the partitions that match the current worker partition.
ids.append((child_stage_id, i, j))
futures.append(f)
elif concurrency == 1 or part == 0:
ids.append((child_stage_id, i, 0))
futures.append(lst)
return ids, futures

# if we are using disk-based shuffle, wait until the child stages finish
# writing the shuffle files to disk first.
@@ -75,9 +143,10 @@ def _get_worker_inputs(part: int) -> list[tuple[int, int, int, ray.ObjectRef]]:
if use_ray_shuffle:
opt["num_returns"] = output_partitions_count
for part in range(concurrency):
ids, inputs = _get_worker_inputs(part)
futures.append(
execute_query_partition.options(**opt).remote(
plan_bytes, part, _get_worker_inputs(part)
stage_id, plan_bytes, part, ids, *inputs
)
)

@@ -86,29 +155,39 @@ def _get_worker_inputs(part: int) -> list[tuple[int, int, int, ray.ObjectRef]]:

@ray.remote
def execute_query_partition(
stage_id: int,
plan_bytes: bytes,
part: int,
input_partition_refs: list[tuple[int, int, int, ray.ObjectRef]],
) -> list[bytes]:
input_partition_ids: list[tuple[int, int, int]],
*input_partitions: list[pa.RecordBatch],
) -> Iterable[pa.RecordBatch]:
start_time = time.time()
plan = raysql.deserialize_execution_plan(plan_bytes)
print(
"Worker executing plan {} partition #{} with shuffle inputs {}".format(
plan.display(),
part,
[(s, i, j) for s, i, j, _ in input_partition_refs],
)
)

input_data = ray.get([f for _, _, _, f in input_partition_refs])
input_partitions = [
(s, j, d) for (s, _, j, _), d in zip(input_partition_refs, input_data)
# print(
# "Worker executing plan {} partition #{} with shuffle inputs {}".format(
# plan.display(),
# part,
# input_partition_ids,
# )
# )
partitions = [
(s, j, p) for (s, _, j), p in zip(input_partition_ids, input_partitions)
]
# This is delegating to DataFusion for execution, but this would be a good place
# to plug in other execution engines by translating the plan into another engine's plan
# (perhaps via Substrait, once DataFusion supports converting a physical plan to Substrait)
result_set = raysql.execute_partition(plan, part, input_partitions)

ret = result_set.tobyteslist()
ret = raysql.execute_partition(plan, part, partitions)
duration = time.time() - start_time
event = {
"cat": f"{stage_id}-{part}",
"name": f"{stage_id}-{part}",
"pid": ray.util.get_node_ip_address(),
"tid": os.getpid(),
"ts": int(start_time * 1_000_000),
"dur": int(duration * 1_000_000),
"ph": "X",
}
print(json.dumps(event), end=",")
return ret[0] if len(ret) == 1 else ret


@@ -124,30 +203,33 @@ def register_csv(self, table_name: str, path: str, has_header: bool):
def register_parquet(self, table_name: str, path: str):
self.ctx.register_parquet(table_name, path)

def sql(self, sql: str) -> ResultSet:
def sql(self, sql: str) -> pa.RecordBatch:
# TODO we should parse sql and inspect the plan rather than
# perform a string comparison here
if 'create view' in sql or 'drop view' in sql:
sql_str = sql.lower()
if "create view" in sql_str or "drop view" in sql_str:
self.ctx.sql(sql)
return raysql.empty_result_set()
return []

graph = self.ctx.plan(sql)
final_stage_id = graph.get_final_query_stage().id()

# serialize the query stages and store in Ray object store
query_stages = [
raysql.serialize_execution_plan(
graph.get_query_stage(i).get_execution_plan()
if self.use_ray_shuffle:
partitions = schedule_execution(graph, final_stage_id, True)
else:
# serialize the query stages and store in Ray object store
query_stages = [
raysql.serialize_execution_plan(
graph.get_query_stage(i).get_execution_plan()
)
for i in range(final_stage_id + 1)
]
# schedule execution
future = execute_query_stage.remote(
query_stages,
final_stage_id,
self.use_ray_shuffle,
)
for i in range(final_stage_id + 1)
]

# schedule execution
future = execute_query_stage.options(**ray_utils.current_node_aff()).remote(
query_stages, final_stage_id, self.num_workers, self.use_ray_shuffle
)
_, partitions = ray.get(future)
# final stage should have a concurrency of 1
assert len(partitions) == 1, partitions
_, partitions = ray.get(future)
# assert len(partitions) == 1, len(partitions)
result_set = ray.get(partitions[0])
return result_set
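
A note on the opt["num_returns"] = output_partitions_count pattern above: it makes each worker task return one ObjectRef per output shuffle partition, so a downstream stage can fetch only the partitions it consumes; it is also why execute_query_partition ends with `return ret[0] if len(ret) == 1 else ret`. A hedged sketch of that Ray mechanism in isolation:

import ray

ray.init()

@ray.remote
def produce(n: int):
    # With num_returns=n, Ray expects a single value when n == 1,
    # otherwise n values that it splits into n separate ObjectRefs.
    parts = [list(range(i, i + 2)) for i in range(n)]
    return parts[0] if n == 1 else tuple(parts)

refs = produce.options(num_returns=3).remote(3)  # a list of 3 ObjectRefs
print(ray.get(refs[1]))  # fetch only partition #1: [1, 2]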
42 changes: 34 additions & 8 deletions raysql/main.py
@@ -1,15 +1,20 @@
import time
import os

from pyarrow import csv as pacsv
import ray
from raysql import RaySqlContext, ResultSet
from raysql import RaySqlContext

NUM_CPUS_PER_WORKER = 8

SF = 10
DATA_DIR = f"/mnt/data0/tpch/sf{SF}-parquet"
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
QUERIES_DIR = os.path.join(SCRIPT_DIR, f"../sqlbench-h/queries/sf={SF}")
RESULTS_DIR = f"results-sf{SF}"
TRUTH_DIR = (
"/home/ubuntu/raysort/ray-sql/sqlbench-runners/spark/{RESULTS_DIR}/{RESULTS_DIR}"
)


def setup_context(use_ray_shuffle: bool, num_workers: int = 2) -> RaySqlContext:
@@ -40,13 +45,30 @@ def tpch_query(ctx: RaySqlContext, q: int = 1):
return result_set


def tpch_timing(ctx: RaySqlContext, q: int = 1, print_result: bool = False):
def tpch_timing(
ctx: RaySqlContext,
q: int = 1,
print_result: bool = False,
write_result: bool = False,
):
sql = load_query(q)
start = time.perf_counter()
result = ctx.sql(sql)
if print_result:
print(ResultSet(result))
end = time.perf_counter()
if print_result:
print("Result:", result)
if isinstance(result, list):
for r in result:
print(r.to_pandas())
else:
print(result.to_pandas())
if write_result:
opt = pacsv.WriteOptions(quoting_style="none")
if isinstance(result, list):
for r in result:
pacsv.write_csv(r, f"{RESULTS_DIR}/q{q}.csv", write_options=opt)
else:
pacsv.write_csv(result, f"{RESULTS_DIR}/q{q}.csv", write_options=opt)
return end - start


@@ -59,21 +81,25 @@ def compare(q: int):

assert result_set_truth == result_set_ray, (
q,
ResultSet(result_set_truth),
ResultSet(result_set_ray),
result_set_truth,
result_set_ray,
)


def tpch_bench():
ray.init("auto")
num_workers = int(ray.cluster_resources().get("worker", 1)) * NUM_CPUS_PER_WORKER
ctx = setup_context(True, num_workers)
use_ray_shuffle = False
ctx = setup_context(use_ray_shuffle, num_workers)
# t = tpch_timing(ctx, 11, print_result=True)
# print(f"query,{t},{use_ray_shuffle},{num_workers}")
# return
run_id = time.strftime("%Y-%m-%d-%H-%M-%S")
with open(f"results-sf{SF}-{run_id}.csv", "w") as fout:
for i in range(1, 22 + 1):
if i == 15:
continue
result = tpch_timing(ctx, i)
result = tpch_timing(ctx, i, write_result=True)
print(f"query,{i},{result}")
print(f"query,{i},{result}", file=fout, flush=True)

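A side note on the timing events that execute_query_partition now prints: the dict with "ph": "X" matches the Chrome Trace Event format for complete events, so the comma-separated stream a run emits can be wrapped into a JSON file and opened in chrome://tracing. A hypothetical helper (the function name and file path are assumptions, not part of this commit):

import json

def write_trace(events_stream: str, path: str = "trace.json") -> None:
    # events_stream is the comma-separated JSON objects printed by the workers.
    events = json.loads("[" + events_stream.rstrip(",") + "]")
    with open(path, "w") as f:
        json.dump({"traceEvents": events}, f)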