Commit

Fixes for Ray-based shuffle (#29)

franklsf95 committed Mar 16, 2023
1 parent 78fb1b0 commit 177790a
Showing 6 changed files with 188 additions and 83 deletions.
33 changes: 16 additions & 17 deletions raysql/context.py
@@ -31,6 +31,7 @@ def execute_query_stage(
output_partitions_count = stage.get_output_partition_count()
if output_partitions_count == 1:
# reduce stage
print("Forcing reduce stage concurrency from {} to 1".format(concurrency))
concurrency = 1

print(
@@ -39,26 +40,24 @@
)
)

# Coordinate shuffle partitions
# A list of (stage ID, list of futures) for each child stage
# Each list is a 2-D array of (input partitions, output partitions).
child_outputs = ray.get(child_futures)

def _get_worker_inputs(part: int) -> dict[int, list[ray.ObjectRef]]:
ret = {}
if not use_ray_shuffle:
return ret
return {c: get_child_inputs(part, lst) for c, lst in child_outputs}

def get_child_inputs(
part: int, inputs: list[list[ray.ObjectRef]]
) -> list[ray.ObjectRef]:
def _get_worker_inputs(part: int) -> list[tuple[int, int, int, ray.ObjectRef]]:
ret = []
for lst in inputs:
if isinstance(lst, list):
num_parts = len(lst)
parts_per_worker = num_parts // concurrency
ret.extend(lst[part * parts_per_worker : (part + 1) * parts_per_worker])
else:
ret.append(lst)
if not use_ray_shuffle:
return []
for child_stage_id, child_futures in child_outputs:
for i, lst in enumerate(child_futures):
if isinstance(lst, list):
for j, f in enumerate(lst):
if concurrency == 1 or j == part:
# If concurrency is 1, pass in all shuffle partitions. Otherwise,
# only pass in the partitions that match the current worker partition.
ret.append((child_stage_id, i, j, f))
elif concurrency == 1 or part == 0:
ret.append((child_stage_id, i, 0, lst))
return ret
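
For illustration, the new selection logic can be exercised standalone on toy data. This sketch is not part of the commit; the stage ID, partition counts, and the strings standing in for ray.ObjectRef are all hypothetical, and the non-list single-output branch is omitted:

# Toy model of _get_worker_inputs: one child stage (ID 1) with two input
# partitions, each shuffled into two output partitions.
child_outputs = [
    (1, [["s1-in0-out0", "s1-in0-out1"], ["s1-in1-out0", "s1-in1-out1"]]),
]
concurrency = 2

def get_worker_inputs(part):
    ret = []
    for child_stage_id, child_futures in child_outputs:
        for i, lst in enumerate(child_futures):
            for j, f in enumerate(lst):
                if concurrency == 1 or j == part:
                    ret.append((child_stage_id, i, j, f))
    return ret

# Worker partition 0 receives only the shuffle outputs with j == 0:
print(get_worker_inputs(0))
# [(1, 0, 0, 's1-in0-out0'), (1, 1, 0, 's1-in1-out0')]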

# if we are using disk-based shuffle, wait for the child stages to finish
30 changes: 25 additions & 5 deletions raysql/main.py
@@ -43,10 +43,30 @@ def load_query(n: int) -> str:
def tpchq(ctx: RaySqlContext, q: int = 14):
sql = load_query(q)
result_set = ray.get(ctx.sql.remote(sql))
print("Result:")
print(ResultSet(result_set))
return result_set


use_ray_shuffle = False
ctx = setup_context(use_ray_shuffle)
tpchq(ctx)
def compare(q: int):
ctx = setup_context(False)
result_set_truth = tpchq(ctx, q)

ctx = setup_context(True)
result_set_ray = tpchq(ctx, q)

assert result_set_truth == result_set_ray, (
q,
ResultSet(result_set_truth),
ResultSet(result_set_ray),
)


# use_ray_shuffle = True
# ctx = setup_context(use_ray_shuffle)
# result_set = tpchq(ctx, 1)
# print("Result:")
# print(ResultSet(result_set))

for i in range(1, 22 + 1):
if i == 15:
continue
compare(i)
16 changes: 8 additions & 8 deletions raysql/worker.py
@@ -13,27 +13,27 @@ def execute_query_partition(
self,
plan_bytes: bytes,
part: int,
input_partitions_map: dict[int, list[ray.ObjectRef]],
input_partition_refs: list[tuple[int, int, int, ray.ObjectRef]],
) -> list[bytes]:
plan = deserialize_execution_plan(plan_bytes)

if self.debug:
print(
"Worker executing plan {} partition #{} with {} shuffle inputs:\n{}".format(
"Worker executing plan {} partition #{} with shuffle inputs {}".format(
plan.display(),
part,
{i: len(parts) for i, parts in input_partitions_map.items()},
plan.display_indent(),
[(s, i, j) for s, i, j, _ in input_partition_refs],
)
)

input_partitions_map = {
i: ray.get(parts) for i, parts in input_partitions_map.items()
}
input_data = ray.get([f for _, _, _, f in input_partition_refs])
input_partitions = [
(s, j, d) for (s, _, j, _), d in zip(input_partition_refs, input_data)
]
# This is delegating to DataFusion for execution, but this would be a good place
# to plug in other execution engines by translating the plan into another engine's plan
# (perhaps via Substrait, once DataFusion supports converting a physical plan to Substrait)
result_set = self.ctx.execute_partition(plan, part, input_partitions_map)
result_set = self.ctx.execute_partition(plan, part, input_partitions)

ret = result_set.tobyteslist()
return ret[0] if len(ret) == 1 else ret
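
For clarity, here is a minimal sketch (not from the commit; all values are placeholders) of the reshaping the worker now performs before execution:

# (stage_id, input_partition i, output_partition j, ray.ObjectRef)
input_partition_refs = [(1, 0, 0, "ref_a"), (1, 1, 0, "ref_b")]

# ray.get(...) would resolve each ref to Arrow IPC bytes; fake that here.
input_data = [b"bytes_a", b"bytes_b"]

# The input-partition index i is dropped: execution needs only the
# producing stage ID and the shuffle output partition j next to the data.
input_partitions = [
    (s, j, d) for (s, _, j, _), d in zip(input_partition_refs, input_data)
]
assert input_partitions == [(1, 0, b"bytes_a"), (1, 0, b"bytes_b")]
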
92 changes: 65 additions & 27 deletions src/context.rs
@@ -1,6 +1,7 @@
use crate::planner::{make_execution_graph, PyExecutionGraph};
use crate::shuffle::{RayShuffleReaderExec, ShuffleCodec};
use crate::utils::wait_for_future;
use datafusion::arrow::error::ArrowError;
use datafusion::arrow::ipc::reader::StreamReader;
use datafusion::arrow::ipc::writer::StreamWriter;
use datafusion::arrow::record_batch::RecordBatch;
@@ -15,14 +16,13 @@ use datafusion::prelude::*;
use datafusion_proto::bytes::{
physical_plan_from_bytes_with_extension_codec, physical_plan_to_bytes_with_extension_codec,
};
use datafusion_python::errors::py_datafusion_err;
use datafusion_python::physical_plan::PyExecutionPlan;
use futures::StreamExt;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList};
use pyo3::types::{PyBytes, PyList, PyLong, PyTuple};
use std::collections::HashMap;
use std::sync::Arc;
use datafusion::arrow::error::ArrowError;
use datafusion_python::errors::py_datafusion_err;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;

@@ -133,22 +133,42 @@ fn _set_inputs_for_ray_shuffle_reader(
py: Python,
) -> Result<()> {
if let Some(reader_exec) = plan.as_any().downcast_ref::<RayShuffleReaderExec>() {
let stage_id = reader_exec.stage_id;
let exec_stage_id = reader_exec.stage_id;
// iterate over inputs, wrap in PyBytes and set as input objects
let input_partitions_map = inputs.as_ref(py).downcast::<PyDict>().map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
match input_partitions_map.get_item(stage_id) {
Some(input_partitions) => {
let input_partitions = input_partitions.downcast::<PyList>().map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
let input_objects = input_partitions
.iter()
.map(|input| input.downcast::<PyBytes>().expect("expected PyBytes").as_bytes().to_vec())
.collect();
reader_exec.set_input_partitions(part, input_objects)?;
let input_partitions = inputs
.as_ref(py)
.downcast::<PyList>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
for item in input_partitions.iter() {
let pytuple = item
.downcast::<PyTuple>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
let stage_id = pytuple
.get_item(0)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.downcast::<PyLong>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.extract::<usize>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
if stage_id != exec_stage_id {
continue;
}
None => {
println!("Warning: No input partitions for stage {}", stage_id);
}
};
let part = pytuple
.get_item(1)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.downcast::<PyLong>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.extract::<usize>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
let bytes = pytuple
.get_item(2)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.downcast::<PyBytes>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.as_bytes()
.to_vec();
reader_exec.add_input_partition(part, bytes)?;
}
} else {
for child in plan.children() {
_set_inputs_for_ray_shuffle_reader(child, part, inputs, py)?;
@@ -159,7 +179,8 @@

impl PyContext {
/// Execute a partition of a query plan. This will typically be executing a shuffle write and
/// write the results to disk, except for the final query stage, which will return the data
/// write the results to disk, except for the final query stage, which will return the data.
/// inputs is a list of tuples of (stage_id, partition_id, bytes) for each input partition.
fn _execute_partition(
&self,
plan: PyExecutionPlan,
@@ -207,16 +228,33 @@ impl PyResultSet {
}
}

fn _read_pybytes(pyobj: &PyAny, batches: &mut Vec<PyRecordBatch>) -> PyResult<()> {
let pybytes = pyobj
.downcast::<PyBytes>()
.map_err(|e| py_datafusion_err(e))?;
let reader = StreamReader::try_new(pybytes.as_bytes(), None).map_err(py_datafusion_err)?;
for batch in reader {
let batch = batch.map_err(|e| py_datafusion_err(e))?;
batches.push(PyRecordBatch::new(batch));
}
Ok(())
}

#[pymethods]
impl PyResultSet {
/// This constructor takes either a list of bytes or a single bytes object.
#[new]
fn py_new(py_obj: &PyBytes) -> PyResult<Self> {
let reader = StreamReader::try_new(py_obj.as_bytes(), None).map_err(py_datafusion_err)?;
fn py_new(pyobj: &PyAny) -> PyResult<Self> {
let mut batches = vec![];
for batch in reader {
let batch = batch.map_err(|e| py_datafusion_err(e))?;
batches.push(PyRecordBatch::new(batch));
}
match pyobj.downcast::<PyList>() {
Ok(pylist) => {
for item in pylist.iter() {
_read_pybytes(item, &mut batches)?;
}
Ok(())
}
_ => _read_pybytes(&pyobj, &mut batches),
}?;
Ok(Self { batches })
}
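
As a usage sketch (assuming pyarrow is available; the column data is made up), bytes acceptable to this constructor can be produced by writing an Arrow IPC stream, which is the format StreamReader parses on the Rust side:

import pyarrow as pa

# Serialize one record batch as an Arrow IPC stream.
batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3]})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
ipc_bytes = sink.getvalue().to_pybytes()

# With the change above, both forms should now construct a result set:
#   ResultSet(ipc_bytes)               # a single bytes object
#   ResultSet([ipc_bytes, ipc_bytes])  # a list of bytes objects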

Expand Down Expand Up @@ -249,13 +287,14 @@ impl PyRecordBatch {
impl PyRecordBatch {
#[new]
fn py_new(py_obj: &PyBytes) -> PyResult<Self> {
let reader = StreamReader::try_new(py_obj.as_bytes(), None).map_err(|e| py_datafusion_err(e))?;
let reader =
StreamReader::try_new(py_obj.as_bytes(), None).map_err(|e| py_datafusion_err(e))?;
let mut batches = vec![];
for r in reader {
batches.push(r.map_err(|e| py_datafusion_err(e))?);
}
if let Some(batch) = batches.pop() {
Ok(Self { batch})
Ok(Self { batch })
} else {
Err(py_datafusion_err("no batches"))
}
@@ -273,7 +312,6 @@ impl PyRecordBatch {
write_batch(&mut buf, &self.batch).map_err(|e| py_datafusion_err(e))?;
Ok(PyBytes::new(py, &buf).into())
}

}

fn write_batch(mut buf: &mut Vec<u8>, batch: &RecordBatch) -> Result<(), ArrowError> {
18 changes: 10 additions & 8 deletions src/shuffle/ray_shuffle/reader.rs
@@ -2,6 +2,7 @@ use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::ipc::reader::StreamReader;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::Statistics;
use datafusion::error::DataFusionError;
use datafusion::execution::context::TaskContext;
use datafusion::physical_expr::expressions::UnKnownColumn;
use datafusion::physical_expr::PhysicalSortExpr;
@@ -17,7 +18,6 @@ use std::io::Cursor;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use datafusion::error::DataFusionError;

type PartitionId = usize;
type StageId = usize;
@@ -58,16 +58,18 @@ impl RayShuffleReaderExec {
}
}

pub fn set_input_partitions(&self, partition: PartitionId, input_partitions: Vec<Vec<u8>>) -> Result<(), DataFusionError> {
pub fn add_input_partition(
&self,
partition: PartitionId,
input_partition: Vec<u8>,
) -> Result<(), DataFusionError> {
let mut map = self.input_partitions_map.write().unwrap();
let input_partitions = map.entry(partition).or_insert(vec![]);
input_partitions.push(input_partition);
println!(
"RayShuffleReaderExec[stage={}].execute(input_partition={partition}) is set with {} shuffle inputs",
"RayShuffleReaderExec[stage={}].execute(input_partition={partition}) adding shuffle input",
self.stage_id,
input_partitions.len(),
);
self.input_partitions_map
.write()
.unwrap()
.insert(partition, input_partitions);
Ok(())
}
}
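
The behavioral change here (appending one shuffle input at a time instead of replacing the whole list) can be modeled in a few lines of Python; this is a toy model, not the commit's code:

input_partitions_map = {}

def add_input_partition(partition, input_bytes):
    # The Python analogue of entry(partition).or_insert(vec![]) + push.
    input_partitions_map.setdefault(partition, []).append(input_bytes)

add_input_partition(0, b"chunk-1")
add_input_partition(0, b"chunk-2")
assert input_partitions_map == {0: [b"chunk-1", b"chunk-2"]}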