[WIP] Use Ray object store for shuffle exchange (#28)
franklsf95 committed Mar 12, 2023
1 parent 273b5aa commit 78fb1b0
Showing 16 changed files with 804 additions and 95 deletions.
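
The commit's core idea: instead of writing shuffle files to disk between query stages, a map-side task returns one object per output partition, and downstream stages read those ObjectRefs straight from Ray's object store. Below is a minimal sketch of that pattern on toy data; the map_task/reduce_task names and the modulo partitioning are illustrative assumptions, not code from this commit.

import ray

ray.init()

NUM_REDUCERS = 2

# Each map task returns one object per reducer; Ray stores every
# returned element as its own ObjectRef in the object store.
@ray.remote(num_returns=NUM_REDUCERS)
def map_task(rows: list[int]) -> list[list[int]]:
    # Hash-partition rows across reducers (toy modulo partitioning).
    parts = [[] for _ in range(NUM_REDUCERS)]
    for row in rows:
        parts[row % NUM_REDUCERS].append(row)
    return parts

@ray.remote
def reduce_task(*parts: list[int]) -> int:
    # Ray resolves the ObjectRef arguments to values before this runs.
    return sum(sum(p) for p in parts)

# Two map tasks, each producing NUM_REDUCERS output partitions.
map_outputs = [map_task.remote(list(range(i, i + 4))) for i in (0, 4)]
# Reducer r consumes partition r of every map task -- no shuffle files on disk.
results = [reduce_task.remote(*(m[r] for m in map_outputs)) for r in range(NUM_REDUCERS)]
print(ray.get(results))  # partition sums, e.g. [12, 16]

Because every partition lands in the object store, reducers can be scheduled as soon as their inputs exist; the diffs below thread a use_ray_shuffle flag through the Python driver to switch between this path and the existing disk-based shuffle.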
3 changes: 2 additions & 1 deletion raysql/__init__.py
@@ -10,8 +10,9 @@
 from ._raysql_internal import (
     Context,
     QueryStage,
+    ResultSet,
     serialize_execution_plan,
-    deserialize_execution_plan
+    deserialize_execution_plan,
 )
 
 __version__ = importlib_metadata.version(__name__)
124 changes: 80 additions & 44 deletions raysql/context.py
@@ -1,80 +1,116 @@
 import ray
-from raysql import Context, QueryStage, serialize_execution_plan
 import time
 
+from raysql import Context, QueryStage, ResultSet, serialize_execution_plan
 from raysql.worker import Worker
 
 
 @ray.remote
-def execute_query_stage(query_stages, stage_id, workers):
-    plan_bytes = ray.get(query_stages[stage_id]).plan_bytes
-    stage = QueryStage(stage_id, plan_bytes)
+def execute_query_stage(
+    query_stages: list[QueryStage],
+    stage_id: int,
+    workers: list[Worker],
+    use_ray_shuffle: bool,
+) -> tuple[int, list[ray.ObjectRef]]:
+    """
+    Execute a query stage on the workers.
+    Returns the stage ID, and a list of futures for the output partitions of the query stage.
+    """
+    stage = QueryStage(stage_id, query_stages[stage_id])
 
     # execute child stages first
     child_futures = []
     for child_id in stage.get_child_stage_ids():
-        child_futures.append(execute_query_stage.remote(query_stages, child_id, workers))
-    ray.get(child_futures)
+        child_futures.append(
+            execute_query_stage.remote(query_stages, child_id, workers, use_ray_shuffle)
+        )
 
     # if the query stage has a single output partition then we need to execute for the output
     # partition, otherwise we need to execute in parallel for each input partition
     concurrency = stage.get_input_partition_count()
-    if stage.get_output_partition_count() == 1:
+    output_partitions_count = stage.get_output_partition_count()
+    if output_partitions_count == 1:
         # reduce stage
         concurrency = 1
 
-    print("Scheduling query stage #{} with {} input partitions and {} output partitions".format(stage.id(), stage.get_input_partition_count(), stage.get_output_partition_count()))
-
-    plan_bytes = ray.put(serialize_execution_plan(stage.get_execution_plan()))
+    print(
+        "Scheduling query stage #{} with {} input partitions and {} output partitions".format(
+            stage.id(), concurrency, output_partitions_count
+        )
+    )
+
+    # Coordinate shuffle partitions
+    child_outputs = ray.get(child_futures)
+
+    def _get_worker_inputs(part: int) -> dict[int, list[ray.ObjectRef]]:
+        ret = {}
+        if not use_ray_shuffle:
+            return ret
+        return {c: get_child_inputs(part, lst) for c, lst in child_outputs}
+
+    def get_child_inputs(
+        part: int, inputs: list[list[ray.ObjectRef]]
+    ) -> list[ray.ObjectRef]:
+        ret = []
+        for lst in inputs:
+            if isinstance(lst, list):
+                num_parts = len(lst)
+                parts_per_worker = num_parts // concurrency
+                ret.extend(lst[part * parts_per_worker : (part + 1) * parts_per_worker])
+            else:
+                ret.append(lst)
+        return ret
+
+    # if we are using disk-based shuffle, wait for the child stages to finish
+    # writing the shuffle files to disk first.
+    if not use_ray_shuffle:
+        ray.get([f for _, lst in child_outputs for f in lst])
+
+    # round-robin allocation across workers
+    plan_bytes = serialize_execution_plan(stage.get_execution_plan())
     futures = []
     for part in range(concurrency):
         worker_index = part % len(workers)
-        futures.append(workers[worker_index].execute_query_partition.remote(plan_bytes, part))
+        opt = {}
+        if use_ray_shuffle:
+            opt["num_returns"] = output_partitions_count
+        futures.append(
+            workers[worker_index]
+            .execute_query_partition.options(**opt)
+            .remote(plan_bytes, part, _get_worker_inputs(part))
+        )
 
-    print("Waiting for query stage #{} to complete".format(stage.id()))
-    start = time.time()
-    result_set = ray.get(futures)
-    end = time.time()
-    print("Query stage #{} completed in {} seconds".format(stage.id(), end-start))
+    return stage_id, futures
 
-    return result_set
 
 @ray.remote
 class RaySqlContext:
-
-    def __init__(self, workers):
-        self.ctx = Context(len(workers))
+    def __init__(self, workers: list[Worker], use_ray_shuffle: bool):
+        self.ctx = Context(len(workers), use_ray_shuffle)
         self.workers = workers
-        self.debug = False
+        self.use_ray_shuffle = use_ray_shuffle
 
-    def register_csv(self, table_name, path, has_header):
+    def register_csv(self, table_name: str, path: str, has_header: bool):
         self.ctx.register_csv(table_name, path, has_header)
 
-    def register_parquet(self, table_name, path):
+    def register_parquet(self, table_name: str, path: str):
         self.ctx.register_parquet(table_name, path)
 
-    def sql(self, sql):
-        if self.debug:
-            print(sql)
-
+    def sql(self, sql: str) -> ResultSet:
         graph = self.ctx.plan(sql)
         final_stage_id = graph.get_final_query_stage().id()
 
         # serialize the query stages and store in Ray object store
-        query_stages = []
-        for i in range(0, final_stage_id+1):
-            stage = graph.get_query_stage(i)
-            plan_bytes = serialize_execution_plan(stage.get_execution_plan())
-            query_stage_serde = QueryStageSerde(i, plan_bytes)
-            query_stages.append(ray.put(query_stage_serde))
+        query_stages = [
+            serialize_execution_plan(graph.get_query_stage(i).get_execution_plan())
+            for i in range(final_stage_id + 1)
+        ]
 
         # schedule execution
-        future = execute_query_stage.remote(query_stages, final_stage_id, self.workers)
-        return ray.get(future)
-
-class QueryStageSerde:
-    def __init__(self, id, plan_bytes):
-        self.id = id
-        self.plan_bytes = plan_bytes
-
-    def __reduce__(self):
-        return (self.__class__, (self.id, self.plan_bytes))
+        future = execute_query_stage.remote(
+            query_stages, final_stage_id, self.workers, self.use_ray_shuffle
+        )
+        _, partitions = ray.get(future)
+        # TODO(@lsf): we only support a single output partition for now?
+        result = ray.get(partitions[0])
+        return result
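
A note on the diff above: each child stage hands back one flat list of output futures per worker task, and get_child_inputs gives reducer `part` a contiguous block of num_parts // concurrency of them from each worker's list. A pure-Python illustration of that allocation, with strings standing in for ObjectRefs (synthetic data, not this repo's types):

# Two child workers, each producing 4 output partitions; 2 reducers downstream.
concurrency = 2
child_outputs = [
    ["w0-p0", "w0-p1", "w0-p2", "w0-p3"],  # worker 0's output partitions
    ["w1-p0", "w1-p1", "w1-p2", "w1-p3"],  # worker 1's output partitions
]

def get_child_inputs(part: int, inputs: list[list[str]]) -> list[str]:
    # Same arithmetic as the diff above: reducer `part` takes a contiguous
    # block of num_parts // concurrency partitions from every worker's list.
    ret = []
    for lst in inputs:
        parts_per_worker = len(lst) // concurrency
        ret.extend(lst[part * parts_per_worker : (part + 1) * parts_per_worker])
    return ret

print(get_child_inputs(0, child_outputs))  # ['w0-p0', 'w0-p1', 'w1-p0', 'w1-p1']
print(get_child_inputs(1, child_outputs))  # ['w0-p2', 'w0-p3', 'w1-p2', 'w1-p3']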
52 changes: 52 additions & 0 deletions raysql/main.py
@@ -0,0 +1,52 @@
+import os
+
+import ray
+from raysql import ResultSet
+from raysql.context import RaySqlContext
+from raysql.worker import Worker
+
+DATA_DIR = "/home/ubuntu/tpch/sf1-parquet"
+# DATA_DIR = "/home/ubuntu/sf10-parquet"
+
+ray.init()
+# ray.init(local_mode=True)
+
+
+def setup_context(use_ray_shuffle: bool) -> RaySqlContext:
+    num_workers = 2
+    # num_workers = os.cpu_count()
+    workers = [Worker.remote() for _ in range(num_workers)]
+    ctx = RaySqlContext.remote(workers, use_ray_shuffle)
+    register_tasks = []
+    register_tasks.append(ctx.register_csv.remote("tips", "examples/tips.csv", True))
+    for table in [
+        "customer",
+        "lineitem",
+        "nation",
+        "orders",
+        "part",
+        "partsupp",
+        "region",
+        "supplier",
+    ]:
+        register_tasks.append(
+            ctx.register_parquet.remote(table, f"{DATA_DIR}/{table}.parquet")
+        )
+    return ctx
+
+
+def load_query(n: int) -> str:
+    with open(f"testdata/queries/q{n}.sql") as fin:
+        return fin.read()
+
+
+def tpchq(ctx: RaySqlContext, q: int = 14):
+    sql = load_query(q)
+    result_set = ray.get(ctx.sql.remote(sql))
+    print("Result:")
+    print(ResultSet(result_set))
+
+
+use_ray_shuffle = False
+ctx = setup_context(use_ray_shuffle)
+tpchq(ctx)
4 changes: 2 additions & 2 deletions raysql/tests/test_context.py
@@ -2,6 +2,6 @@
 from raysql import Context
 
 def test():
-    ctx = Context()
+    ctx = Context(1, False)
     ctx.register_csv('tips', 'examples/tips.csv', True)
-    ctx.plan("SELECT * FROM tips")
+    ctx.plan("SELECT * FROM tips")
31 changes: 24 additions & 7 deletions raysql/worker.py
@@ -1,22 +1,39 @@
 import ray
 
 from raysql import Context, deserialize_execution_plan
 
 
 @ray.remote
 class Worker:
     def __init__(self):
-        self.ctx = Context(1)
-        self.debug = False
+        self.ctx = Context(1, False)
+        self.debug = True
 
-    def execute_query_partition(self, plan_bytes, part):
+    def execute_query_partition(
+        self,
+        plan_bytes: bytes,
+        part: int,
+        input_partitions_map: dict[int, list[ray.ObjectRef]],
+    ) -> list[bytes]:
         plan = deserialize_execution_plan(plan_bytes)
 
         if self.debug:
-            print("Executing partition #{}:\n{}".format(part, plan.display_indent()))
+            print(
+                "Worker executing plan {} partition #{} with {} shuffle inputs:\n{}".format(
+                    plan.display(),
+                    part,
+                    {i: len(parts) for i, parts in input_partitions_map.items()},
+                    plan.display_indent(),
+                )
+            )
 
+        input_partitions_map = {
+            i: ray.get(parts) for i, parts in input_partitions_map.items()
+        }
         # This is delegating to DataFusion for execution, but this would be a good place
         # to plug in other execution engines by translating the plan into another engine's plan
         # (perhaps via Substrait, once DataFusion supports converting a physical plan to Substrait)
-        results = self.ctx.execute_partition(plan, part)
+        result_set = self.ctx.execute_partition(plan, part, input_partitions_map)
 
-        # TODO: return results here instead of string representation of results
-        return "{}".format(results)
+        ret = result_set.tobyteslist()
+        return ret[0] if len(ret) == 1 else ret
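
The return statement above encodes Ray's num_returns contract: a task launched with num_returns=1 must return a single value, while num_returns=k must return a sequence of k values, which Ray then stores as k separate ObjectRefs. A small sketch of that convention with a hypothetical Echo actor (not this repo's Worker):

import ray

ray.init()

@ray.remote
class Echo:
    def run(self, parts: list[bytes]):
        out = [bytes(p) for p in parts]
        # Mirror the convention above: with num_returns == 1 Ray expects a
        # single value; with num_returns == k it expects a sequence of k.
        return out[0] if len(out) == 1 else out

echo = Echo.remote()
single = echo.run.options(num_returns=1).remote([b"a"])      # one ObjectRef
many = echo.run.options(num_returns=2).remote([b"a", b"b"])  # list of two ObjectRefs
print(ray.get(single))                     # b'a'
print(ray.get(many[0]), ray.get(many[1]))  # b'a' b'b'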
(Diffs for the remaining 11 changed files did not load.)