Make better use of futures (#23)
andygrove committed Feb 4, 2023
1 parent b952203 commit 1621999
Showing 10 changed files with 210 additions and 136 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

10 changes: 10 additions & 0 deletions README.md
@@ -3,6 +3,16 @@
This is a personal research project to evaluate performing distributed SQL queries from Python, using
[Ray](https://www.ray.io/) and [DataFusion](https://github.com/apache/arrow-datafusion).

## Goals

- Demonstrate how easily new systems can be built on top of DataFusion
- Drive requirements for DataFusion's Python bindings
- Create content for an interesting blog post or conference talk

## Non Goals

- Build and support a production system

## Example

Run the following example live in your browser using a Google Colab [notebook](https://colab.research.google.com/drive/1tmSX0Lu6UFh58_-DBUVoyYx6BoXHOszP?usp=sharing).
9 changes: 5 additions & 4 deletions examples/tips.py
@@ -8,12 +8,13 @@
# create some remote Workers
workers = [Worker.remote() for i in range(2)]

# create context and plan a query
ctx = RaySqlContext(workers)
ctx.register_csv('tips', 'tips.csv', True)
# create a remote context and register a table
ctx = RaySqlContext.remote(workers)
ray.get(ctx.register_csv.remote('tips', 'tips.csv', True))

# Parquet is also supported
# ctx.register_parquet('tips', 'tips.parquet')

result_set = ctx.sql('select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker')
result_set = ray.get(ctx.sql.remote('select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker'))
print(result_set)
# print(ray.get(result_set))
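
Since `RaySqlContext` is now a `@ray.remote` actor (see raysql/context.py below), every call in the example goes through `.remote()` and returns an ObjectRef that must be resolved with `ray.get`. A minimal sketch of that calling pattern; the import paths and the `ray.init()` call are assumptions, since the top of examples/tips.py is outside this hunk:

```python
import ray

# assumed import locations; the real example's imports are not shown in this hunk
from raysql.context import RaySqlContext
from raysql.worker import Worker

ray.init()

# two worker actors plus one remote context actor
workers = [Worker.remote() for _ in range(2)]
ctx = RaySqlContext.remote(workers)

# .remote() returns an ObjectRef immediately; ray.get blocks until the value is ready
ray.get(ctx.register_csv.remote('tips', 'tips.csv', True))
result_set = ray.get(ctx.sql.remote('select count(*) from tips'))
print(result_set)
```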
7 changes: 5 additions & 2 deletions raysql/__init__.py
@@ -8,12 +8,15 @@
import importlib_metadata

from ._raysql_internal import (
Context
Context,
QueryStage,
serialize_execution_plan,
deserialize_execution_plan
)

__version__ = importlib_metadata.version(__name__)

__all__ = [
"Context",
"Worker"
"Worker",
]
94 changes: 59 additions & 35 deletions raysql/context.py
@@ -1,12 +1,50 @@
import ray
from raysql import Context
from raysql import Context, QueryStage, serialize_execution_plan
import time

@ray.remote
def execute_query_stage(query_stages, stage_id, workers):
plan_bytes = ray.get(query_stages[stage_id]).plan_bytes
stage = QueryStage(stage_id, plan_bytes)

# execute child stages first
child_futures = []
for child_id in stage.get_child_stage_ids():
child_futures.append(execute_query_stage.remote(query_stages, child_id, workers))
ray.get(child_futures)

# if the query stage has a single output partition then we need to execute for the output
# partition, otherwise we need to execute in parallel for each input partition
if stage.get_output_partition_count() == 1:
partition_count = 1
else:
partition_count = stage.get_input_partition_count()

print("Scheduling query stage #{} with {} input partitions and {} output partitions".format(stage.id(), partition_count, stage.get_output_partition_count()))

plan_bytes = ray.put(serialize_execution_plan(stage.get_execution_plan()))

# round-robin allocation across workers
futures = []
for part in range(partition_count):
worker_index = part % len(workers)
futures.append(workers[worker_index].execute_query_partition.remote(plan_bytes, part))

print("Waiting for query stage #{} to complete".format(stage.id()))
start = time.time()
result_set = ray.get(futures)
end = time.time()
print("Query stage #{} completed in {} seconds".format(stage.id(), end-start))

return result_set

@ray.remote
class RaySqlContext:

def __init__(self, workers):
self.ctx = Context(len(workers))
self.workers = workers
self.debug = False

def register_csv(self, table_name, path, has_header):
self.ctx.register_csv(table_name, path, has_header)
@@ -15,42 +53,28 @@ def register_parquet(self, table_name, path):
self.ctx.register_parquet(table_name, path)

def sql(self, sql):
graph = self.ctx.plan(sql)
# recurse down the tree and build a DAG of futures
final_stage = graph.get_final_query_stage()
# schedule execution
return self.execute_query_stage(graph, final_stage)

def execute_query_stage(self, graph, stage):

# TODO make better use of futures here so that more runs in parallel
if self.debug:
print(sql)

# execute child stages first
for child_id in stage.get_child_stage_ids():
child_stage = graph.get_query_stage(child_id)
self.execute_query_stage(graph, child_stage)

# todo what is correct logic here?
if stage.get_output_partition_count == 1:
partition_count = 1
else:
partition_count = stage.get_input_partition_count()

print("Scheduling query stage #{} with {} input partitions and {} output partitions".format(stage.id(), partition_count, stage.get_output_partition_count()))
graph = self.ctx.plan(sql)
final_stage_id = graph.get_final_query_stage().id()

# serialize the plan
plan_bytes = ray.put(self.ctx.serialize_execution_plan(stage.get_execution_plan()))
# serialize the query stages and store in Ray object store
query_stages = []
for i in range(0, final_stage_id+1):
stage = graph.get_query_stage(i)
plan_bytes = serialize_execution_plan(stage.get_execution_plan())
query_stage_serde = QueryStageSerde(i, plan_bytes)
query_stages.append(ray.put(query_stage_serde))

# round-robin allocation across workers
futures = []
for part in range(partition_count):
worker_index = part % len(self.workers)
futures.append(self.workers[worker_index].execute_query_partition.remote(plan_bytes, part))
# schedule execution
future = execute_query_stage.remote(query_stages, final_stage_id, self.workers)
return ray.get(future)

print("Waiting for query stage #{} to complete".format(stage.id()))
start = time.time()
result_set = ray.get(futures)
end = time.time()
print("Query stage #{} completed in {} seconds".format(stage.id(), end-start))
class QueryStageSerde:
def __init__(self, id, plan_bytes):
self.id = id
self.plan_bytes = plan_bytes

return result_set
def __reduce__(self):
return (self.__class__, (self.id, self.plan_bytes))
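
The heart of this commit is the module-level `execute_query_stage` task above: child stages are launched as `@ray.remote` tasks rather than executed through a plain recursive method call, so sibling subtrees of the stage DAG can run concurrently, and `QueryStageSerde` implements `__reduce__` so the stage id and serialized plan can round-trip through the Ray object store. A toy sketch of the scheduling pattern with DataFusion stripped out, to show the futures behaviour in isolation:

```python
import ray

ray.init()

@ray.remote
def execute_stage(children, stage_id):
    # launching children as remote tasks lets sibling subtrees run concurrently;
    # the old code walked the graph with ordinary method calls, one stage at a time
    child_futures = [execute_stage.remote(children, child_id) for child_id in children[stage_id]]
    ray.get(child_futures)
    return "stage {} done".format(stage_id)

# toy stage graph: stage 2 depends on stages 0 and 1
children = {0: [], 1: [], 2: [0, 1]}
print(ray.get(execute_stage.remote(children, 2)))
```

Because each recursive call is itself a remote task, Ray can schedule the two children of stage 2 at the same time, which is exactly what the old "TODO make better use of futures" comment was asking for.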
4 changes: 2 additions & 2 deletions raysql/worker.py
@@ -1,5 +1,5 @@
import ray
from raysql import Context
from raysql import Context, deserialize_execution_plan

@ray.remote
class Worker:
@@ -8,7 +8,7 @@ def __init__(self):
self.debug = False

def execute_query_partition(self, plan_bytes, part):
plan = self.ctx.deserialize_execution_plan(plan_bytes)
plan = deserialize_execution_plan(plan_bytes)

if self.debug:
print("Executing partition #{}:\n{}".format(part, plan.display_indent()))
26 changes: 15 additions & 11 deletions src/context.rs
@@ -78,17 +78,6 @@ impl PyContext {
Ok(PyExecutionGraph::new(graph))
}

fn serialize_execution_plan(&self, plan: PyExecutionPlan) -> PyResult<Vec<u8>> {
let codec = ShuffleCodec {};
Ok(physical_plan_to_bytes_with_extension_codec(plan.plan, &codec)?.to_vec())
}

fn deserialize_execution_plan(&self, bytes: Vec<u8>) -> PyResult<PyExecutionPlan> {
let codec = ShuffleCodec {};
Ok(PyExecutionPlan::new(
physical_plan_from_bytes_with_extension_codec(&bytes, &self.ctx, &codec)?,
))
}

/// Execute a partition of a query plan. This will typically be executing a shuffle write and write the results to disk
pub fn execute_partition(&self, plan: PyExecutionPlan, part: usize) -> PyResultSet {
@@ -102,6 +91,21 @@
}
}

#[pyfunction]
pub fn serialize_execution_plan(plan: PyExecutionPlan) -> PyResult<Vec<u8>> {
let codec = ShuffleCodec {};
Ok(physical_plan_to_bytes_with_extension_codec(plan.plan, &codec)?.to_vec())
}

#[pyfunction]
pub fn deserialize_execution_plan(bytes: Vec<u8>) -> PyResult<PyExecutionPlan> {
let ctx = SessionContext::new();
let codec = ShuffleCodec {};
Ok(PyExecutionPlan::new(
physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?,
))
}

impl PyContext {
/// Execute a partition of a query plan. This will typically be executing a shuffle write and
/// write the results to disk, except for the final query stage, which will return the data
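
`serialize_execution_plan` and `deserialize_execution_plan` are now free `#[pyfunction]`s exported to Python (see raysql/__init__.py above), and the deserializer builds its own `SessionContext` instead of borrowing the one inside `PyContext`. A rough sketch of the round trip as it might look from Python, using only calls that appear elsewhere in this diff:

```python
from raysql import Context, serialize_execution_plan, deserialize_execution_plan

# plan a query locally and grab the final stage's physical plan
ctx = Context(2)
ctx.register_csv('tips', 'tips.csv', True)
graph = ctx.plan('select count(*) from tips')
stage = graph.get_final_query_stage()

# round-trip the plan through bytes; no Context is needed on the receiving side
plan_bytes = serialize_execution_plan(stage.get_execution_plan())
plan = deserialize_execution_plan(plan_bytes)
print(plan.display_indent())
```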
5 changes: 5 additions & 0 deletions src/lib.rs
@@ -4,16 +4,21 @@ use pyo3::prelude::*;

mod proto;
pub use proto::generated::protobuf;
use crate::context::{serialize_execution_plan, deserialize_execution_plan};

pub mod context;
pub mod planner;
pub mod shuffle;
pub mod utils;
pub mod query_stage;

/// A Python module implemented in Rust.
#[pymodule]
fn _raysql_internal(_py: Python, m: &PyModule) -> PyResult<()> {
// register classes that can be created directly from Python code
m.add_class::<context::PyContext>()?;
m.add_class::<query_stage::PyQueryStage>()?;
m.add_function(wrap_pyfunction!(serialize_execution_plan, m)?)?;
m.add_function(wrap_pyfunction!(deserialize_execution_plan, m)?)?;
Ok(())
}
90 changes: 9 additions & 81 deletions src/planner.rs
@@ -1,11 +1,12 @@
use crate::query_stage::QueryStage;
use crate::query_stage::PyQueryStage;
use crate::shuffle::{ShuffleReaderExec, ShuffleWriterExec};
use datafusion::error::Result;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion::physical_plan::Partitioning;
use datafusion::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
use datafusion_python::physical_plan::PyExecutionPlan;
use log::debug;
use pyo3::prelude::*;
use std::collections::HashMap;
@@ -26,24 +27,27 @@ impl PyExecutionGraph {

#[pymethods]
impl PyExecutionGraph {

/// Get a list of stages sorted by id
pub fn get_query_stages(&self) -> Vec<PyQueryStage> {
let mut stages = vec![];
for stage in self.graph.query_stages.values() {
stages.push(PyQueryStage::new(stage.clone()));
let max_id = self.graph.get_final_query_stage().id;
for id in 0..=max_id {
stages.push(PyQueryStage::from_rust(self.graph.query_stages.get(&id).unwrap().clone()));
}
stages
}

pub fn get_query_stage(&self, id: usize) -> PyResult<PyQueryStage> {
if let Some(stage) = self.graph.query_stages.get(&id) {
Ok(PyQueryStage::new(stage.clone()))
Ok(PyQueryStage::from_rust(stage.clone()))
} else {
todo!()
}
}

pub fn get_final_query_stage(&self) -> PyQueryStage {
PyQueryStage::new(self.graph.get_final_query_stage())
PyQueryStage::from_rust(self.graph.get_final_query_stage())
}
}

@@ -91,83 +95,7 @@ impl ExecutionGraph {
}
}

#[pyclass(name = "QueryStage", module = "raysql", subclass)]
pub struct PyQueryStage {
stage: Arc<QueryStage>,
}

impl PyQueryStage {
pub fn new(stage: Arc<QueryStage>) -> Self {
Self { stage }
}
}

#[pymethods]
impl PyQueryStage {
pub fn id(&self) -> usize {
self.stage.id
}

pub fn get_execution_plan(&self) -> PyExecutionPlan {
PyExecutionPlan::new(self.stage.plan.clone())
}

pub fn get_child_stage_ids(&self) -> Vec<usize> {
self.stage.get_child_stage_ids()
}

pub fn get_input_partition_count(&self) -> usize {
self.stage.get_input_partition_count()
}

pub fn get_output_partition_count(&self) -> usize {
self.stage.plan.output_partitioning().partition_count()
}
}

#[derive(Debug)]
pub struct QueryStage {
pub id: usize,
pub plan: Arc<dyn ExecutionPlan>,
}

impl QueryStage {
pub fn new(id: usize, plan: Arc<dyn ExecutionPlan>) -> Self {
Self { id, plan }
}

pub fn get_child_stage_ids(&self) -> Vec<usize> {
let mut ids = vec![];
collect_child_stage_ids(self.plan.as_ref(), &mut ids);
ids
}

/// Get the input partition count. This is the same as the number of concurrent tasks
/// when we schedule this query stage for execution
pub fn get_input_partition_count(&self) -> usize {
collect_input_partition_count(self.plan.as_ref())
}
}

fn collect_child_stage_ids(plan: &dyn ExecutionPlan, ids: &mut Vec<usize>) {
if let Some(shuffle_reader) = plan.as_any().downcast_ref::<ShuffleReaderExec>() {
ids.push(shuffle_reader.stage_id);
} else {
for child_plan in plan.children() {
collect_child_stage_ids(child_plan.as_ref(), ids);
}
}
}

fn collect_input_partition_count(plan: &dyn ExecutionPlan) -> usize {
if plan.children().is_empty() {
plan.output_partitioning().partition_count()
} else {
// invariants:
// - all inputs must have the same partition count
collect_input_partition_count(plan.children()[0].as_ref())
}
}

pub fn make_execution_graph(plan: Arc<dyn ExecutionPlan>) -> Result<ExecutionGraph> {
let mut graph = ExecutionGraph::new();
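
With `PyQueryStage` moved out of planner.rs (into the new src/query_stage.rs module registered in src/lib.rs above) and `get_query_stages` now returning stages ordered by id, the stage graph can be inspected from Python with a simple loop. A brief sketch, assuming the accessors used in raysql/context.py keep the signatures shown in the removed code above:

```python
from raysql import Context

ctx = Context(2)
ctx.register_csv('tips', 'tips.csv', True)
graph = ctx.plan('select sex, count(*) from tips group by sex')

# stages come back sorted by id, with the final stage last
for stage in graph.get_query_stages():
    print(stage.id(),
          stage.get_child_stage_ids(),
          stage.get_input_partition_count(),
          stage.get_output_partition_count())
```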