Skip to content

Commit

Permalink
Decimal support
Browse files Browse the repository at this point in the history
  • Loading branch information
Jefffrey committed Mar 23, 2024
1 parent 70a0262 commit bb885c0
Show file tree
Hide file tree
Showing 14 changed files with 249 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Read [Apache ORC](https://orc.apache.org/) in Rust.
| Boolean || | Boolean |
| TinyInt || | Int8 |
| Binary || | Binary |
| Decimal | | | |
| Decimal | | | Decimal128 |
| Date || | Date32 |
| Timestamp || | Timestamp(Nanosecond,_) |
| Timestamp instant || | |
Expand Down
103 changes: 103 additions & 0 deletions src/arrow_reader/decoder/decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use std::cmp::Ordering;

use crate::arrow_reader::column::get_present_vec;
use crate::arrow_reader::decoder::DecimalArrayDecoder;
use crate::error::Result;
use crate::proto::stream::Kind;
use crate::reader::decode::decimal::UnboundedVarintStreamDecoder;
use crate::stripe::Stripe;
use crate::{arrow_reader::column::Column, reader::decode::get_rle_reader};

use super::ArrayBatchDecoder;

pub fn new_decimal_decoder(
column: &Column,
stripe: &Stripe,
precision: u32,
fixed_scale: u32,
) -> Result<Box<dyn ArrayBatchDecoder>> {
let varint_iter = stripe.stream_map.get(column, Kind::Data)?;
let varint_iter = Box::new(UnboundedVarintStreamDecoder::new(
varint_iter,
stripe.number_of_rows,
));
let varint_iter = varint_iter as Box<dyn Iterator<Item = Result<i128>> + Send>;

// Scale is specified on a per varint basis (in addition to being encoded in the type)
let scale_iter = stripe.stream_map.get(column, Kind::Secondary)?;
let scale_iter = get_rle_reader::<i32, _>(column, scale_iter)?;

let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);

let iter = DecimalScaleRepairIter {
varint_iter,
scale_iter,
fixed_scale,
};
let iter = Box::new(iter);

Ok(Box::new(DecimalArrayDecoder::new(
precision as u8,
fixed_scale as i8,
iter,
present,
)))
}

/// This iter fixes the scales of the varints decoded as scale is specified on a per
/// varint basis, and needs to align with type specified scale
struct DecimalScaleRepairIter {
varint_iter: Box<dyn Iterator<Item = Result<i128>> + Send>,
scale_iter: Box<dyn Iterator<Item = Result<i32>> + Send>,
fixed_scale: u32,
}

impl DecimalScaleRepairIter {
#[inline]
fn next_helper(&mut self, varint: Result<i128>, scale: Result<i32>) -> Result<Option<i128>> {
let varint = varint?;
let scale = scale?;
Ok(Some(fix_i128_scale(varint, self.fixed_scale, scale)))
}
}

impl Iterator for DecimalScaleRepairIter {
type Item = Result<i128>;

fn next(&mut self) -> Option<Self::Item> {
let varint = self.varint_iter.next()?;
let scale = self.scale_iter.next()?;
self.next_helper(varint, scale).transpose()
}
}

fn fix_i128_scale(i: i128, fixed_scale: u32, varying_scale: i32) -> i128 {
// TODO: Verify with C++ impl in ORC repo, which does this cast
// Not sure why scale stream can be signed if it gets casted to unsigned anyway
// https://github.com/apache/orc/blob/0014bec1e4cdd1206f5bae4f5c2000b9300c6eb1/c%2B%2B/src/ColumnReader.cc#L1459-L1476
let varying_scale = varying_scale as u32;
match fixed_scale.cmp(&varying_scale) {
Ordering::Less => {
// fixed_scale < varying_scale
// Current scale of number is greater than scale of the array type
// So need to divide to align the scale
// TODO: this differs from C++ implementation, need to verify
let scale_factor = varying_scale - fixed_scale;
// TODO: replace with lookup table?
let scale_factor = 10_i128.pow(scale_factor);
i / scale_factor
}
Ordering::Equal => i,
Ordering::Greater => {
// fixed_scale > varying_scale
// Current scale of number is smaller than scale of the array type
// So need to multiply to align the scale
// TODO: this differs from C++ implementation, need to verify
let scale_factor = fixed_scale - varying_scale;
// TODO: replace with lookup table?
let scale_factor = 10_i128.pow(scale_factor);
i * scale_factor
}
}
}
50 changes: 47 additions & 3 deletions src/arrow_reader/decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray, BooleanBuilder, PrimitiveArray, PrimitiveBuilder};
use arrow::buffer::NullBuffer;
use arrow::datatypes::{ArrowPrimitiveType, UInt64Type};
use arrow::datatypes::{ArrowPrimitiveType, Decimal128Type, UInt64Type};
use arrow::datatypes::{
Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef,
TimestampNanosecondType,
};
use arrow::record_batch::RecordBatch;
use snafu::ResultExt;

use crate::error::{self, Result};
use crate::error::{self, ArrowSnafu, Result};
use crate::proto::stream::Kind;
use crate::reader::decode::boolean_rle::BooleanIter;
use crate::reader::decode::byte_rle::ByteRleIter;
Expand All @@ -19,6 +19,7 @@ use crate::reader::decode::get_rle_reader;
use crate::schema::DataType;
use crate::stripe::Stripe;

use self::decimal::new_decimal_decoder;
use self::list::ListArrayDecoder;
use self::map::MapArrayDecoder;
use self::string::{new_binary_decoder, new_string_decoder};
Expand All @@ -27,6 +28,7 @@ use self::struct_decoder::StructArrayDecoder;
use super::column::timestamp::TimestampIterator;
use super::column::{get_present_vec, Column};

mod decimal;
mod list;
mod map;
mod string;
Expand Down Expand Up @@ -106,6 +108,46 @@ type Float64ArrayDecoder = PrimitiveArrayDecoder<Float64Type>;
type TimestampArrayDecoder = PrimitiveArrayDecoder<TimestampNanosecondType>;
type DateArrayDecoder = PrimitiveArrayDecoder<Date32Type>; // TODO: does ORC encode as i64 or i32?

/// Wrapper around PrimitiveArrayDecoder to allow specifying the precision and scale
/// of the output decimal array.
struct DecimalArrayDecoder {
precision: u8,
scale: i8,
inner: PrimitiveArrayDecoder<Decimal128Type>,
}

impl DecimalArrayDecoder {
pub fn new(
precision: u8,
scale: i8,
iter: Box<dyn Iterator<Item = Result<i128>> + Send>,
present: Option<Box<dyn Iterator<Item = bool> + Send>>,
) -> Self {
let inner = PrimitiveArrayDecoder::<Decimal128Type>::new(iter, present);
Self {
precision,
scale,
inner,
}
}
}

impl ArrayBatchDecoder for DecimalArrayDecoder {
fn next_batch(
&mut self,
batch_size: usize,
parent_present: Option<&[bool]>,
) -> Result<ArrayRef> {
let array = self
.inner
.next_primitive_batch(batch_size, parent_present)?
.with_precision_and_scale(self.precision, self.scale)
.context(ArrowSnafu)?;
let array = Arc::new(array) as ArrayRef;
Ok(array)
}
}

struct BooleanArrayDecoder {
iter: Box<dyn Iterator<Item = Result<bool>> + Send>,
present: Option<Box<dyn Iterator<Item = bool> + Send>>,
Expand Down Expand Up @@ -331,7 +373,9 @@ pub fn array_decoder_factory(
new_string_decoder(column, stripe)?
}
DataType::Binary { .. } => new_binary_decoder(column, stripe)?,
DataType::Decimal { .. } => todo!(),
DataType::Decimal {
precision, scale, ..
} => new_decimal_decoder(column, stripe, *precision, *scale)?,
DataType::Timestamp { .. } => {
let data = stripe.stream_map.get(column, Kind::Data)?;
let data = get_rle_reader(column, data)?;
Expand Down
29 changes: 29 additions & 0 deletions src/reader/decode/decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use std::io::Read;

use crate::{error::Result, reader::decode::util::read_varint_zigzagged};

/// Read stream of zigzag encoded varints as i128 (unbound).
pub struct UnboundedVarintStreamDecoder<R: Read> {
reader: R,
remaining: usize,
}

impl<R: Read> UnboundedVarintStreamDecoder<R> {
pub fn new(reader: R, expected_length: usize) -> Self {
Self {
reader,
remaining: expected_length,
}
}
}

impl<R: Read> Iterator for UnboundedVarintStreamDecoder<R> {
type Item = Result<i128>;

fn next(&mut self) -> Option<Self::Item> {
(self.remaining > 0).then(|| {
self.remaining -= 1;
read_varint_zigzagged::<i128, _>(&mut self.reader)
})
}
}
29 changes: 29 additions & 0 deletions src/reader/decode/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use self::util::{signed_msb_decode, signed_zigzag_decode};

pub mod boolean_rle;
pub mod byte_rle;
pub mod decimal;
pub mod float;
pub mod rle_v1;
pub mod rle_v2;
Expand Down Expand Up @@ -221,3 +222,31 @@ impl NInt for u64 {
Self::from_be_bytes(b)
}
}

// This impl is used only for varint decoding.
// Hence some methods are left unimplemented since they are not used.
// TODO: maybe split NInt into traits for the specific use case
// - patched base decoding
// - varint decoding
// - etc.
impl NInt for i128 {
type Bytes = [u8; 16];
const BYTE_SIZE: usize = 16;

fn from_u64(_u: u64) -> Self {
unimplemented!()
}

fn from_u8(u: u8) -> Self {
u as Self
}

fn from_be_bytes(_b: Self::Bytes) -> Self {
unimplemented!()
}

#[inline]
fn zigzag_decode(self) -> Self {
signed_zigzag_decode(self)
}
}
2 changes: 1 addition & 1 deletion src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ impl DataType {
DataType::Binary { .. } => ArrowDataType::Binary,
DataType::Decimal {
precision, scale, ..
} => ArrowDataType::Decimal128(*precision as u8, *scale as i8),
} => ArrowDataType::Decimal128(*precision as u8, *scale as i8), // TODO: safety of cast?
DataType::Timestamp { .. } => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
DataType::TimestampWithLocalTimezone { .. } => {
// TODO: get writer timezone
Expand Down
Binary file modified tests/basic/data/alltypes.lz4.orc
Binary file not shown.
Binary file modified tests/basic/data/alltypes.lzo.orc
Binary file not shown.
Binary file modified tests/basic/data/alltypes.none.orc
Binary file not shown.
Binary file modified tests/basic/data/alltypes.snappy.orc
Binary file not shown.
Binary file modified tests/basic/data/alltypes.zlib.orc
Binary file not shown.
Binary file modified tests/basic/data/alltypes.zstd.orc
Binary file not shown.
44 changes: 23 additions & 21 deletions tests/basic/data/generate_orc.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,39 @@
import shutil
import glob
from datetime import date
from decimal import Decimal
from pyspark.sql import SparkSession
from pyspark.sql.types import *

spark = SparkSession.builder.getOrCreate()

# TODO: int8, char, varchar, decimal, timestamp, struct, list, map, union
df = spark.createDataFrame(
[ # bool, int16, int32, int64, float32, float64, binary, utf8, date32
( None, None, None, None, None, None, None, None, None),
( True, 0, 0, 0, 0.0, 0.0, "".encode(), "", date(1970, 1, 1)),
(False, 1, 1, 1, 1.0, 1.0, "a".encode(), "a", date(1970, 1, 2)),
(False, -1, -1, -1, -1.0, -1.0, " ".encode(), " ", date(1969, 12, 31)),
( True, (1 << 15) - 1, (1 << 31) - 1, (1 << 63) - 1, float("inf"), float("inf"), "encode".encode(), "encode", date(9999, 12, 31)),
( True, -(1 << 15), -(1 << 31), -(1 << 63), float("-inf"), float("-inf"), "decode".encode(), "decode", date(1582, 10, 15)),
( True, 50, 50, 50, 3.1415927, 3.14159265359, "大熊和奏".encode(), "大熊和奏", date(1582, 10, 16)),
( True, 51, 51, 51, -3.1415927, -3.14159265359, "斉藤朱夏".encode(), "斉藤朱夏", date(2000, 1, 1)),
( True, 52, 52, 52, 1.1, 1.1, "鈴原希実".encode(), "鈴原希実", date(3000, 12, 31)),
(False, 53, 53, 53, -1.1, -1.1, "🤔".encode(), "🤔", date(1900, 1, 1)),
( None, None, None, None, None, None, None, None, None),
[ # bool, int16, int32, int64, float32, float64, binary, utf8, date32, decimal
( None, None, None, None, None, None, None, None, None, None),
( True, 0, 0, 0, 0.0, 0.0, "".encode(), "", date(1970, 1, 1), Decimal(0)),
(False, 1, 1, 1, 1.0, 1.0, "a".encode(), "a", date(1970, 1, 2), Decimal(1)),
(False, -1, -1, -1, -1.0, -1.0, " ".encode(), " ", date(1969, 12, 31), Decimal(-1)),
( True, (1 << 15) - 1, (1 << 31) - 1, (1 << 63) - 1, float("inf"), float("inf"), "encode".encode(), "encode", date(9999, 12, 31), Decimal(123456789.12345)),
( True, -(1 << 15), -(1 << 31), -(1 << 63), float("-inf"), float("-inf"), "decode".encode(), "decode", date(1582, 10, 15), Decimal(-999999999.99999)),
( True, 50, 50, 50, 3.1415927, 3.14159265359, "大熊和奏".encode(), "大熊和奏", date(1582, 10, 16), Decimal(-31256.123)),
( True, 51, 51, 51, -3.1415927, -3.14159265359, "斉藤朱夏".encode(), "斉藤朱夏", date(2000, 1, 1), Decimal(1241000)),
( True, 52, 52, 52, 1.1, 1.1, "鈴原希実".encode(), "鈴原希実", date(3000, 12, 31), Decimal(1.1)),
(False, 53, 53, 53, -1.1, -1.1, "🤔".encode(), "🤔", date(1900, 1, 1), Decimal(0.99999)),
( None, None, None, None, None, None, None, None, None, None),
],
StructType(
[
StructField("boolean", BooleanType()),
StructField( "int16", ShortType()),
StructField( "int32", IntegerType()),
StructField( "int64", LongType()),
StructField("float32", FloatType()),
StructField("float64", DoubleType()),
StructField( "binary", BinaryType()),
StructField( "utf8", StringType()),
StructField( "date32", DateType()),
StructField("boolean", BooleanType()),
StructField( "int16", ShortType()),
StructField( "int32", IntegerType()),
StructField( "int64", LongType()),
StructField("float32", FloatType()),
StructField("float64", DoubleType()),
StructField( "binary", BinaryType()),
StructField( "utf8", StringType()),
StructField( "date32", DateType()),
StructField("decimal", DecimalType(15, 5)),
]
),
).coalesce(1)
Expand Down
Loading

0 comments on commit bb885c0

Please sign in to comment.