[WIP] Use Ray object store for shuffle exchange (#28)
Parent: 273b5aa
Commit: 78fb1b0
16 changed files with 804 additions and 95 deletions.
@@ -1,80 +1,116 @@
 import ray
-from raysql import Context, QueryStage, serialize_execution_plan
-import time
+
+from raysql import Context, QueryStage, ResultSet, serialize_execution_plan
+from raysql.worker import Worker
 
 
 @ray.remote
-def execute_query_stage(query_stages, stage_id, workers):
-    plan_bytes = ray.get(query_stages[stage_id]).plan_bytes
-    stage = QueryStage(stage_id, plan_bytes)
+def execute_query_stage(
+    query_stages: list[QueryStage],
+    stage_id: int,
+    workers: list[Worker],
+    use_ray_shuffle: bool,
+) -> tuple[int, list[ray.ObjectRef]]:
+    """
+    Execute a query stage on the workers.
+    Returns the stage ID, and a list of futures for the output partitions of the query stage.
+    """
+    stage = QueryStage(stage_id, query_stages[stage_id])
 
     # execute child stages first
     child_futures = []
     for child_id in stage.get_child_stage_ids():
-        child_futures.append(execute_query_stage.remote(query_stages, child_id, workers))
-    ray.get(child_futures)
+        child_futures.append(
+            execute_query_stage.remote(query_stages, child_id, workers, use_ray_shuffle)
+        )
 
     # if the query stage has a single output partition then we need to execute for the output
     # partition, otherwise we need to execute in parallel for each input partition
     concurrency = stage.get_input_partition_count()
-    if stage.get_output_partition_count() == 1:
+    output_partitions_count = stage.get_output_partition_count()
+    if output_partitions_count == 1:
         # reduce stage
         concurrency = 1
 
-    print("Scheduling query stage #{} with {} input partitions and {} output partitions".format(stage.id(), stage.get_input_partition_count(), stage.get_output_partition_count()))
-
-    plan_bytes = ray.put(serialize_execution_plan(stage.get_execution_plan()))
+    print(
+        "Scheduling query stage #{} with {} input partitions and {} output partitions".format(
+            stage.id(), concurrency, output_partitions_count
+        )
+    )
+
+    # Coordinate shuffle partitions
+    child_outputs = ray.get(child_futures)
+
+    def _get_worker_inputs(part: int) -> dict[int, list[ray.ObjectRef]]:
+        ret = {}
+        if not use_ray_shuffle:
+            return ret
+        return {c: get_child_inputs(part, lst) for c, lst in child_outputs}
+
+    def get_child_inputs(
+        part: int, inputs: list[list[ray.ObjectRef]]
+    ) -> list[ray.ObjectRef]:
+        ret = []
+        for lst in inputs:
+            if isinstance(lst, list):
+                num_parts = len(lst)
+                parts_per_worker = num_parts // concurrency
+                ret.extend(lst[part * parts_per_worker : (part + 1) * parts_per_worker])
+            else:
+                ret.append(lst)
+        return ret
+
+    # if we are using disk-based shuffle, wait until the child stages to finish
+    # writing the shuffle files to disk first.
+    if not use_ray_shuffle:
+        ray.get([f for _, lst in child_outputs for f in lst])
 
     # round-robin allocation across workers
+    plan_bytes = serialize_execution_plan(stage.get_execution_plan())
     futures = []
     for part in range(concurrency):
         worker_index = part % len(workers)
-        futures.append(workers[worker_index].execute_query_partition.remote(plan_bytes, part))
+        opt = {}
+        if use_ray_shuffle:
+            opt["num_returns"] = output_partitions_count
+        futures.append(
+            workers[worker_index]
+            .execute_query_partition.options(**opt)
+            .remote(plan_bytes, part, _get_worker_inputs(part))
+        )
 
-    print("Waiting for query stage #{} to complete".format(stage.id()))
-    start = time.time()
-    result_set = ray.get(futures)
-    end = time.time()
-    print("Query stage #{} completed in {} seconds".format(stage.id(), end-start))
-
-    return result_set
+    return stage_id, futures
 
 
 @ray.remote
 class RaySqlContext:
-
-    def __init__(self, workers):
-        self.ctx = Context(len(workers))
+    def __init__(self, workers: list[Worker], use_ray_shuffle: bool):
+        self.ctx = Context(len(workers), use_ray_shuffle)
         self.workers = workers
-        self.debug = False
+        self.use_ray_shuffle = use_ray_shuffle
 
-    def register_csv(self, table_name, path, has_header):
+    def register_csv(self, table_name: str, path: str, has_header: bool):
         self.ctx.register_csv(table_name, path, has_header)
 
-    def register_parquet(self, table_name, path):
+    def register_parquet(self, table_name: str, path: str):
         self.ctx.register_parquet(table_name, path)
 
-    def sql(self, sql):
-        if self.debug:
-            print(sql)
-
+    def sql(self, sql: str) -> ResultSet:
        graph = self.ctx.plan(sql)
         final_stage_id = graph.get_final_query_stage().id()
 
         # serialize the query stages and store in Ray object store
-        query_stages = []
-        for i in range(0, final_stage_id+1):
-            stage = graph.get_query_stage(i)
-            plan_bytes = serialize_execution_plan(stage.get_execution_plan())
-            query_stage_serde = QueryStageSerde(i, plan_bytes)
-            query_stages.append(ray.put(query_stage_serde))
+        query_stages = [
+            serialize_execution_plan(graph.get_query_stage(i).get_execution_plan())
+            for i in range(final_stage_id + 1)
+        ]
 
         # schedule execution
-        future = execute_query_stage.remote(query_stages, final_stage_id, self.workers)
-        return ray.get(future)
-
-
-class QueryStageSerde:
-    def __init__(self, id, plan_bytes):
-        self.id = id
-        self.plan_bytes = plan_bytes
-
-    def __reduce__(self):
-        return (self.__class__, (self.id, self.plan_bytes))
+        future = execute_query_stage.remote(
+            query_stages, final_stage_id, self.workers, self.use_ray_shuffle
+        )
+        _, partitions = ray.get(future)
+        # TODO(@lsf): we only support a single output partition for now?
+        result = ray.get(partitions[0])
+        return result
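
In `get_child_inputs` above, each of the `concurrency` tasks in a stage takes a contiguous slice of every child stage's output partitions (note the integer division, which assumes the partition count divides evenly). A minimal plain-Python sketch of that slicing, with no Ray involved; the `obj-*` strings stand in for `ray.ObjectRef`s:

```python
# Each consumer task `part` receives a contiguous slice of the child's outputs.
concurrency = 2
child_output = [f"obj-{i}" for i in range(8)]  # stand-ins for ray.ObjectRef


def slice_for(part: int) -> list[str]:
    parts_per_worker = len(child_output) // concurrency
    return child_output[part * parts_per_worker : (part + 1) * parts_per_worker]


print(slice_for(0))  # ['obj-0', 'obj-1', 'obj-2', 'obj-3']
print(slice_for(1))  # ['obj-4', 'obj-5', 'obj-6', 'obj-7']
```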
@@ -0,0 +1,52 @@
+import os
+
+import ray
+from raysql import ResultSet
+from raysql.context import RaySqlContext
+from raysql.worker import Worker
+
+DATA_DIR = "/home/ubuntu/tpch/sf1-parquet"
+# DATA_DIR = "/home/ubuntu/sf10-parquet"
+
+ray.init()
+# ray.init(local_mode=True)
+
+
+def setup_context(use_ray_shuffle: bool) -> RaySqlContext:
+    num_workers = 2
+    # num_workers = os.cpu_count()
+    workers = [Worker.remote() for _ in range(num_workers)]
+    ctx = RaySqlContext.remote(workers, use_ray_shuffle)
+    register_tasks = []
+    register_tasks.append(ctx.register_csv.remote("tips", "examples/tips.csv", True))
+    for table in [
+        "customer",
+        "lineitem",
+        "nation",
+        "orders",
+        "part",
+        "partsupp",
+        "region",
+        "supplier",
+    ]:
+        register_tasks.append(
+            ctx.register_parquet.remote(table, f"{DATA_DIR}/{table}.parquet")
+        )
+    return ctx
+
+
+def load_query(n: int) -> str:
+    with open(f"testdata/queries/q{n}.sql") as fin:
+        return fin.read()
+
+
+def tpchq(ctx: RaySqlContext, q: int = 14):
+    sql = load_query(q)
+    result_set = ray.get(ctx.sql.remote(sql))
+    print("Result:")
+    print(ResultSet(result_set))
+
+
+use_ray_shuffle = False
+ctx = setup_context(use_ray_shuffle)
+tpchq(ctx)
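
The driver script above still runs the disk-based path (`use_ray_shuffle = False`). To exercise the object-store shuffle this commit introduces, the flag at the bottom would be flipped; a hypothetical variant, assuming the rest of the setup is unchanged:

```python
# Run TPC-H q14 with the new Ray object-store shuffle enabled.
use_ray_shuffle = True
ctx = setup_context(use_ray_shuffle)
tpchq(ctx, q=14)
```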
@@ -1,22 +1,39 @@
 import ray
 
 from raysql import Context, deserialize_execution_plan
 
 
 @ray.remote
 class Worker:
     def __init__(self):
-        self.ctx = Context(1)
-        self.debug = False
+        self.ctx = Context(1, False)
+        self.debug = True
 
-    def execute_query_partition(self, plan_bytes, part):
+    def execute_query_partition(
+        self,
+        plan_bytes: bytes,
+        part: int,
+        input_partitions_map: dict[int, list[ray.ObjectRef]],
+    ) -> list[bytes]:
         plan = deserialize_execution_plan(plan_bytes)
 
         if self.debug:
-            print("Executing partition #{}:\n{}".format(part, plan.display_indent()))
+            print(
+                "Worker executing plan {} partition #{} with {} shuffle inputs:\n{}".format(
+                    plan.display(),
+                    part,
+                    {i: len(parts) for i, parts in input_partitions_map.items()},
+                    plan.display_indent(),
+                )
+            )
 
+        input_partitions_map = {
+            i: ray.get(parts) for i, parts in input_partitions_map.items()
+        }
         # This is delegating to DataFusion for execution, but this would be a good place
         # to plug in other execution engines by translating the plan into another engine's plan
         # (perhaps via Substrait, once DataFusion supports converting a physical plan to Substrait)
-        results = self.ctx.execute_partition(plan, part)
+        result_set = self.ctx.execute_partition(plan, part, input_partitions_map)
 
-        # TODO: return results here instead of string representation of results
-        return "{}".format(results)
+        ret = result_set.tobyteslist()
+        return ret[0] if len(ret) == 1 else ret
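
On the worker side, `input_partitions_map` maps each child stage id to a list of `ObjectRef`s, and one `ray.get` per child stage materializes the serialized batches before execution. A minimal standalone sketch of that pattern; the payload bytes are hypothetical:

```python
import ray

ray.init()

# Simulate two child stages whose shuffle outputs already live in the object store.
refs = {
    0: [ray.put(b"batch-a"), ray.put(b"batch-b")],
    1: [ray.put(b"batch-c")],
}

# One ray.get per child stage pulls all of its partitions out of the object store.
materialized = {stage_id: ray.get(parts) for stage_id, parts in refs.items()}
print(materialized)  # {0: [b'batch-a', b'batch-b'], 1: [b'batch-c']}
```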