Implement shuffle more fully (#7)
andygrove committed Jan 28, 2023
1 parent b99b73f commit 99ccc19
Showing 7 changed files with 118 additions and 26 deletions.
5 changes: 4 additions & 1 deletion examples/tips.py
@@ -11,4 +11,7 @@
# create context and plan a query
ctx = RaySqlContext(workers)
ctx.register_csv('tips', 'tips.csv', True)
ctx.sql('select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker')

ctx.sql('select day, sum(total_bill) from tips group by day')

#ctx.sql('select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker')
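For reference, a minimal plain-Python sketch of what the new example query (select day, sum(total_bill) ... group by day) computes, over a few made-up rows rather than the real tips.csv data:

# Plain-Python sketch of the `group by day` aggregation in the example above;
# the rows below are invented, not taken from tips.csv.
rows = [
    {"day": "Sun", "total_bill": 16.99},
    {"day": "Sun", "total_bill": 10.34},
    {"day": "Sat", "total_bill": 21.01},
]
totals = {}
for r in rows:
    totals[r["day"]] = totals.get(r["day"], 0.0) + r["total_bill"]
print({d: round(v, 2) for d, v in totals.items()})  # {'Sun': 27.33, 'Sat': 21.01}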
6 changes: 4 additions & 2 deletions raysql/context.py
@@ -29,15 +29,17 @@ def execute_query_stage(self, graph, stage):
child_stage = graph.get_query_stage(child_id)
self.execute_query_stage(graph, child_stage)

print("Scheduling query stage #{}".format(stage.id()))
partition_count = stage.get_input_partition_count()
print("Scheduling query stage #{} with {} input partitions and {} output partitions".format(stage.id(), partition_count, stage.get_output_partition_count()))

# serialize the plan
plan_bytes = self.ctx.serialize_execution_plan(stage.get_execution_plan())

# round-robin allocation across workers
futures = []
for part in range(stage.get_input_partition_count()):
for part in range(partition_count):
worker_index = part % len(self.workers)
print("Asking worker {} to execute partition {}".format(worker_index, part))
futures.append(self.workers[worker_index].execute_query_partition.remote(plan_bytes, part))

print("Waiting for query stage #{} to complete".format(stage.id()))
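The scheduling loop above fans each input partition out to a worker round-robin. A standalone sketch of that assignment policy (plain Python; the worker names are placeholders, not the real Ray actor handles):

# Standalone sketch of the round-robin partition-to-worker assignment used in
# execute_query_stage; worker names are placeholders, not Ray actor handles.
def assign_partitions(partition_count, workers):
    return {part: workers[part % len(workers)] for part in range(partition_count)}

print(assign_partitions(4, ["worker-0", "worker-1"]))
# {0: 'worker-0', 1: 'worker-1', 2: 'worker-0', 3: 'worker-1'}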
12 changes: 7 additions & 5 deletions src/context.rs
@@ -4,6 +4,7 @@ use crate::utils::wait_for_future;
use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::displayable;
@@ -28,8 +29,9 @@ pub struct PyContext {
impl PyContext {
#[new]
pub fn new() -> Self {
let config = SessionConfig::default().with_target_partitions(4);
Self {
ctx: SessionContext::default(),
ctx: SessionContext::with_config(config),
}
}

@@ -84,8 +86,6 @@ impl PyContext {

/// Execute a partition of a query plan. This will typically be executing a shuffle write and write the results to disk
pub fn execute_partition(&self, plan: PyExecutionPlan, part: usize) -> PyResult<()> {
println!("Executing: {}", plan.display_indent());

let ctx = Arc::new(TaskContext::new(
"task_id".to_string(),
"session_id".to_string(),
@@ -100,13 +100,15 @@

let fut = rt.spawn(async move {
let mut stream = plan.plan.execute(part, ctx)?;
let mut results = vec![];
while let Some(result) = stream.next().await {
let input_batch = result?;
println!("received batch with {} rows", input_batch.num_rows());
results.push(input_batch);
}

// TODO remove this dummy batch
println!("Results:\n{}", pretty_format_batches(&results)?);

// TODO remove this dummy batch
// create a dummy batch to return - later this could be metadata about the
// shuffle partitions that were written out
let schema = Arc::new(Schema::new(vec![Field::new("foo", DataType::Int32, true)]));
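execute_partition now buffers every batch from the stream and pretty-prints the combined result instead of logging per-batch row counts. A rough pyarrow equivalent of that collect-then-print step (the batches below are invented; the Rust code uses arrow's pretty_format_batches):

# Rough pyarrow equivalent of collecting all batches and printing them once;
# the batches here are invented for illustration.
import pyarrow as pa

batches = [
    pa.RecordBatch.from_pydict({"day": ["Sun", "Sat"], "sum": [27.33, 21.01]}),
    pa.RecordBatch.from_pydict({"day": ["Thur"], "sum": [9.99]}),
]
table = pa.Table.from_batches(batches)  # concatenate results from all batches
print(table)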
4 changes: 4 additions & 0 deletions src/planner.rs
@@ -109,6 +109,10 @@ impl PyQueryStage {
pub fn get_input_partition_count(&self) -> usize {
self.stage.get_input_partition_count()
}

pub fn get_output_partition_count(&self) -> usize {
self.stage.plan.output_partitioning().partition_count()
}
}

#[derive(Debug)]
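get_output_partition_count reports the partition count declared by the stage plan's output partitioning, so the scheduler can log both the number of tasks it launches (input partitions) and the fan-out of the shuffle the stage writes (output partitions). A toy Python stand-in, since DataFusion's Partitioning enum is not exposed here:

# Toy stand-in for the plan's output partitioning; only the partition count
# matters to the scheduler. Illustrative only, not the DataFusion API.
from dataclasses import dataclass

@dataclass
class HashPartitioning:
    exprs: list
    partition_count: int

stage_output = HashPartitioning(exprs=["day"], partition_count=4)
print(stage_output.partition_count)  # what get_output_partition_count() returns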
4 changes: 2 additions & 2 deletions src/shuffle/codec.rs
@@ -65,7 +65,7 @@ impl PhysicalExtensionCodec for ShuffleCodec {
partition: 0,
schema: Some(schema),
num_output_partitions: 1,
shuffle_dir: "/tmp/raysql-shuffle".to_string(),
shuffle_dir: "/tmp/raysql".to_string(), // TODO remove hard-coded path
};
PlanType::ShuffleReader(reader)
} else if let Some(writer) = node.as_any().downcast_ref::<ShuffleWriterExec>() {
@@ -75,7 +75,7 @@
plan: Some(plan),
partition_expr: vec![],
num_output_partitions: 1,
shuffle_dir: "/tmp/raysql-shuffle".to_string(),
shuffle_dir: "/tmp/raysql".to_string(), // TODO remove hard-coded path
};
PlanType::ShuffleWriter(writer)
} else {
6 changes: 3 additions & 3 deletions src/shuffle/reader.rs
@@ -1,4 +1,4 @@
use datafusion::arrow::datatypes::{Schema, SchemaRef};
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::ipc::reader::FileReader;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::Statistics;
@@ -8,7 +8,6 @@ use datafusion::physical_plan::{
DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
};
use futures::Stream;
use prost::Message;
use std::any::Any;
use std::fmt::Formatter;
use std::fs::File;
@@ -63,7 +62,8 @@ impl ExecutionPlan for ShuffleReaderExec {
partition: usize,
_context: Arc<TaskContext>,
) -> datafusion::common::Result<SendableRecordBatchStream> {
let file = format!("/tmp/raysql/{}_{partition}.arrow", self.stage_id);
// TODO remove hard-coded path
let file = format!("/tmp/raysql/stage_{}_part_{partition}.arrow", self.stage_id);
println!("Shuffle reader reading from {file}");
let reader = FileReader::try_new(File::open(&file)?, None)?;
Ok(Box::pin(LocalShuffleStream::new(reader)))
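The reader resolves one Arrow IPC file per (stage, partition) pair. A pyarrow sketch of the same lookup; the path convention mirrors the Rust code above, the stage and partition numbers are illustrative, and the file is assumed to have been written by a previous shuffle stage:

# Sketch of the reader's file lookup in pyarrow; path convention copied from
# the Rust code above, stage/partition numbers are illustrative.
import pyarrow.ipc as ipc

stage_id, partition = 1, 0
path = f"/tmp/raysql/stage_{stage_id}_part_{partition}.arrow"

reader = ipc.open_file(path)  # Arrow IPC file format, same as arrow's FileReader
for i in range(reader.num_record_batches):
    batch = reader.get_batch(i)
    print(f"batch {i}: {batch.num_rows} rows")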
107 changes: 94 additions & 13 deletions src/shuffle/writer.rs
@@ -2,12 +2,14 @@ use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::ipc::writer::FileWriter;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::common::{Result, Statistics};
use datafusion::execution::context::TaskContext;
use datafusion::physical_expr::PhysicalSortExpr;
use datafusion::physical_plan::common::batch_byte_size;
use datafusion::physical_plan::common::{batch_byte_size, IPCWriter};
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder};
use datafusion::physical_plan::repartition::BatchPartitioner;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
metrics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
@@ -19,6 +21,7 @@ use futures::TryStreamExt;
use std::any::Any;
use std::fmt::Formatter;
use std::fs::File;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;

@@ -49,7 +52,7 @@ impl ExecutionPlan for ShuffleWriterExec {
}

fn output_partitioning(&self) -> Partitioning {
todo!()
self.plan.output_partitioning()
}

fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
@@ -73,20 +76,98 @@

fn execute(
&self,
partition: usize,
input_partition: usize,
context: Arc<TaskContext>,
) -> datafusion::common::Result<SendableRecordBatchStream> {
let mut stream = self.plan.execute(partition, context)?;
let file = format!("/tmp/raysql/{}_{partition}.arrow", self.stage_id);
let write_time = MetricBuilder::new(&self.metrics).subset_time("write_time", partition);
println!("ShuffleWriteExec::execute(input_partition={input_partition})");

let mut stream = self.plan.execute(input_partition, context)?;
let write_time =
MetricBuilder::new(&self.metrics).subset_time("write_time", input_partition);
let repart_time =
MetricBuilder::new(&self.metrics).subset_time("repart_time", input_partition);

let stage_id = self.stage_id;
let partitioning = self.output_partitioning();
let partition_count = partitioning.partition_count();

let results = async move {
// stream the results from the query
println!("Executing query and writing results to {file}");
let stats = write_stream_to_disk(&mut stream, &file, &write_time).await?;
println!(
"Query completed. Shuffle write time: {}. Rows: {}.",
write_time, stats.num_rows
);
if partition_count == 1 {
// stream the results from the query
// TODO remove hard-coded path
let file = format!("/tmp/raysql/stage_{}_part_0.arrow", stage_id);
println!("Executing query and writing results to {file}");
let stats = write_stream_to_disk(&mut stream, &file, &write_time).await?;
println!(
"Query completed. Shuffle write time: {}. Rows: {}.",
write_time, stats.num_rows
);
} else {
// we won't necessarily produce output for every possible partition, so we
// create writers on demand
let mut writers: Vec<Option<IPCWriter>> = vec![];
for _ in 0..partition_count {
writers.push(None);
}

let mut partitioner = BatchPartitioner::try_new(partitioning, repart_time.clone())?;

let mut rows = 0;

while let Some(result) = stream.next().await {
let input_batch = result?;
rows += input_batch.num_rows();

println!(
"ShuffleWriterExec writing batch:\n{}",
pretty_format_batches(&[input_batch.clone()])?
);

//write_metrics.input_rows.add(input_batch.num_rows());

partitioner.partition(input_batch, |output_partition, output_batch| {
match &mut writers[output_partition] {
Some(w) => {
w.write(&output_batch)?;
}
None => {
// TODO remove hard-coded path
let path = format!(
"/tmp/raysql/stage_{}_part_{}.arrow",
stage_id, output_partition
);
let path = Path::new(&path);
println!("Writing results to {:?}", path);

let mut writer = IPCWriter::new(&path, stream.schema().as_ref())?;

writer.write(&output_batch)?;
writers[output_partition] = Some(writer);
}
}
Ok(())
})?;
}

for (i, w) in writers.iter_mut().enumerate() {
match w {
Some(w) => {
w.finish()?;
println!(
"Finished writing shuffle partition {} at {:?}. Batches: {}. Rows: {}. Bytes: {}.",
i,
w.path(),
w.num_batches,
w.num_rows,
w.num_bytes
);
}
None => {}
}
}

println!("finished processing stream with {rows} rows");
}

// create a dummy batch to return - later this could be metadata about the
// shuffle partitions that were written out
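When the stage has more than one output partition, the writer above hashes each incoming batch into output partitions and creates an IPC writer only for partitions that actually receive rows. A small pyarrow sketch of that "writers on demand" pattern, with a toy row-wise hash partitioner standing in for DataFusion's BatchPartitioner and a temporary directory instead of the hard-coded /tmp/raysql:

# Sketch of the on-demand writer pattern: an IPC writer is created only for
# output partitions that receive rows. The row-wise hash partitioner is a toy
# stand-in for DataFusion's BatchPartitioner; the directory is temporary.
import tempfile
import pyarrow as pa
import pyarrow.ipc as ipc

stage_id, num_partitions = 1, 4
shuffle_dir = tempfile.mkdtemp()
schema = pa.schema([("day", pa.string()), ("total_bill", pa.float64())])
writers = [None] * num_partitions      # like Vec<Option<IPCWriter>>, filled lazily

batch = pa.RecordBatch.from_pydict(
    {"day": ["Sun", "Sat", "Sun"], "total_bill": [16.99, 21.01, 10.34]},
    schema=schema,
)

days = batch.column(0).to_pylist()     # partition key values
for row, day in enumerate(days):
    part = hash(day) % num_partitions  # toy partitioner, not Arrow's hash
    if writers[part] is None:
        path = f"{shuffle_dir}/stage_{stage_id}_part_{part}.arrow"
        writers[part] = ipc.new_file(path, schema)
    writers[part].write_batch(batch.slice(row, 1))

for w in writers:
    if w is not None:
        w.close()                      # like IPCWriter::finish()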
