[FEAT] Read from LanceDB (#2195)
Adds the ability to read a Lance dataset:
https://lancedb.github.io/lance/read_and_write.html

---------

Co-authored-by: Jay Chia <jaychia94@gmail.com@users.noreply.github.com>
jaychia and Jay Chia committed May 1, 2024
1 parent e3756f5 commit c4928f8
Showing 20 changed files with 553 additions and 84 deletions.
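For orientation, here is a minimal usage sketch of the API this commit introduces. The dataset URL below is hypothetical; `read_lance` accepts local paths as well as object-store URLs such as `s3://` or `gs://`.

```python
import daft

# Hypothetical dataset URL; any local path or supported object-store URL works.
df = daft.read_lance("s3://my-bucket/my_table.lance")
df.show()
```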
2 changes: 2 additions & 0 deletions daft/__init__.py
@@ -78,6 +78,7 @@ def get_build_type() -> str:
read_json,
read_parquet,
read_sql,
read_lance,
)
from daft.series import Series
from daft.udf import udf
@@ -98,6 +99,7 @@ def get_build_type() -> str:
"read_iceberg",
"read_delta_lake",
"read_sql",
"read_lance",
"DataCatalogType",
"DataCatalogTable",
"DataFrame",
19 changes: 19 additions & 0 deletions daft/daft.pyi
@@ -635,6 +635,21 @@ class ScanTask:
Create a SQL Scan Task
"""
...
@staticmethod
def python_factory_func_scan_task(
module: str,
func_name: str,
func_args: tuple[Any, ...],
schema: PySchema,
num_rows: int | None,
size_bytes: int | None,
pushdowns: Pushdowns | None,
stats: PyTable | None,
) -> ScanTask:
"""
Create a Python factory function Scan Task
"""
...

class ScanOperatorHandle:
"""
@@ -699,6 +714,10 @@ class Pushdowns:
partition_filters: PyExpr | None
limit: int | None

def filter_required_column_names(self) -> list[str]:
"""List of field names that are required by the filter predicate."""
...

def read_parquet(
uri: str,
columns: list[str] | None = None,
73 changes: 2 additions & 71 deletions daft/delta_lake/delta_lake_scan.py
@@ -3,24 +3,18 @@
import logging
import os
from collections.abc import Iterator
from typing import Any
from urllib.parse import urlparse

from deltalake.table import DeltaTable

import daft
from daft.daft import (
AzureConfig,
FileFormatConfig,
GCSConfig,
IOConfig,
NativeStorageConfig,
ParquetSourceConfig,
Pushdowns,
S3Config,
ScanTask,
StorageConfig,
)
from daft.io.object_store_options import io_config_to_storage_options
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.schema import Schema

@@ -30,7 +24,7 @@
class DeltaLakeScanOperator(ScanOperator):
def __init__(self, table_uri: str, storage_config: StorageConfig) -> None:
super().__init__()
storage_options = _storage_config_to_storage_options(storage_config, table_uri)
storage_options = io_config_to_storage_options(storage_config.config.io_config, table_uri)
self._table = DeltaTable(table_uri, storage_options=storage_options)
self._storage_config = storage_config
self._schema = Schema.from_pyarrow_schema(self._table.schema().to_pyarrow())
@@ -165,66 +159,3 @@ def can_absorb_limit(self) -> bool:

def can_absorb_select(self) -> bool:
return True


def _storage_config_to_storage_options(storage_config: StorageConfig, table_uri: str) -> dict[str, str] | None:
"""
Converts the Daft storage config to a storage options dict that deltalake/object_store
understands.
"""
config = storage_config.config
assert isinstance(config, NativeStorageConfig)
io_config = config.io_config
return _io_config_to_storage_options(io_config, table_uri)


def _io_config_to_storage_options(io_config: IOConfig, table_uri: str) -> dict[str, str] | None:
scheme = urlparse(table_uri).scheme
if scheme == "s3" or scheme == "s3a":
return _s3_config_to_storage_options(io_config.s3)
elif scheme == "gcs" or scheme == "gs":
return _gcs_config_to_storage_options(io_config.gcs)
elif scheme == "az" or scheme == "abfs":
return _azure_config_to_storage_options(io_config.azure)
else:
return None


def _s3_config_to_storage_options(s3_config: S3Config) -> dict[str, str]:
storage_options: dict[str, Any] = {}
if s3_config.region_name is not None:
storage_options["region"] = s3_config.region_name
if s3_config.endpoint_url is not None:
storage_options["endpoint_url"] = s3_config.endpoint_url
if s3_config.key_id is not None:
storage_options["access_key_id"] = s3_config.key_id
if s3_config.session_token is not None:
storage_options["session_token"] = s3_config.session_token
if s3_config.access_key is not None:
storage_options["secret_access_key"] = s3_config.access_key
if s3_config.use_ssl is not None:
storage_options["allow_http"] = "false" if s3_config.use_ssl else "true"
if s3_config.verify_ssl is not None:
storage_options["allow_invalid_certificates"] = "false" if s3_config.verify_ssl else "true"
if s3_config.connect_timeout_ms is not None:
storage_options["connect_timeout"] = str(s3_config.connect_timeout_ms) + "ms"
if s3_config.anonymous:
storage_options["skip_signature"] = "true"
return storage_options


def _azure_config_to_storage_options(azure_config: AzureConfig) -> dict[str, str]:
storage_options = {}
if azure_config.storage_account is not None:
storage_options["account_name"] = azure_config.storage_account
if azure_config.access_key is not None:
storage_options["access_key"] = azure_config.access_key
if azure_config.endpoint_url is not None:
storage_options["endpoint"] = azure_config.endpoint_url
if azure_config.use_ssl is not None:
storage_options["allow_http"] = "false" if azure_config.use_ssl else "true"
return storage_options


def _gcs_config_to_storage_options(_: GCSConfig) -> dict[str, str]:
return {}
2 changes: 2 additions & 0 deletions daft/io/__init__.py
@@ -14,6 +14,7 @@
from daft.io._hudi import read_hudi
from daft.io._iceberg import read_iceberg
from daft.io._json import read_json
from daft.io._lance import read_lance
from daft.io._parquet import read_parquet
from daft.io._sql import read_sql
from daft.io.catalog import DataCatalogTable, DataCatalogType
@@ -42,6 +43,7 @@ def _set_linux_cert_paths():
"read_hudi",
"read_iceberg",
"read_delta_lake",
"read_lance",
"read_sql",
"IOConfig",
"S3Config",
126 changes: 126 additions & 0 deletions daft/io/_lance.py
@@ -0,0 +1,126 @@
# isort: dont-add-import: from __future__ import annotations

from typing import TYPE_CHECKING, Iterator, List, Optional

from daft import context
from daft.api_annotations import PublicAPI
from daft.daft import IOConfig, Pushdowns, PyTable, ScanOperatorHandle, ScanTask
from daft.dataframe import DataFrame
from daft.io.object_store_options import io_config_to_storage_options
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.builder import LogicalPlanBuilder
from daft.logical.schema import Schema
from daft.table import Table

if TYPE_CHECKING:
import lance


def _lancedb_table_factory_function(
fragment: "lance.LanceFragment", required_columns: Optional[List[str]]
) -> Iterator["PyTable"]:
return (
Table.from_arrow_record_batches([rb], rb.schema)._table for rb in fragment.to_batches(columns=required_columns)
)


@PublicAPI
def read_lance(url: str, io_config: Optional["IOConfig"] = None) -> DataFrame:
"""Create a DataFrame from a LanceDB table
.. NOTE::
This function requires the use of `LanceDB <https://lancedb.github.io/lancedb/>`_, which is the Python
library for the LanceDB project.
Args:
url: URL to the LanceDB table (supports remote URLs to object stores such as `s3://` or `gs://`)
io_config: A custom IOConfig to use when accessing LanceDB data. Defaults to None.
Returns:
DataFrame: a DataFrame with the schema converted from the specified LanceDB table
"""

try:
import lance
except ImportError as e:
raise ImportError(
"Unable to import the `lance` package, please ensure that Daft is installed with the lance extra dependency: `pip install getdaft[lance]`"
) from e

io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
storage_options = io_config_to_storage_options(io_config, url)

ds = lance.dataset(url, storage_options=storage_options)
    lance_operator = LanceDBScanOperator(ds)

    handle = ScanOperatorHandle.from_python_scan_operator(lance_operator)
builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
return DataFrame(builder)


class LanceDBScanOperator(ScanOperator):
def __init__(self, ds: "lance.LanceDataset"):
self._ds = ds

def display_name(self) -> str:
return f"LanceDBScanOperator({self._ds.uri})"

def schema(self) -> Schema:
return Schema.from_pyarrow_schema(self._ds.schema)

def partitioning_keys(self) -> List[PartitionField]:
return []

def can_absorb_filter(self) -> bool:
return False

def can_absorb_limit(self) -> bool:
return False

def can_absorb_select(self) -> bool:
return False

def multiline_display(self) -> List[str]:
return [
self.display_name(),
f"Schema = {self.schema()}",
]

def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
required_columns: Optional[List[str]]
if pushdowns.columns is None:
required_columns = None
else:
filter_required_column_names = pushdowns.filter_required_column_names()
required_columns = (
pushdowns.columns
if filter_required_column_names is None
else pushdowns.columns + filter_required_column_names
)

# TODO: figure out how to translate Pushdowns into LanceDB filters
filters = None
fragments = self._ds.get_fragments(filter=filters)
for i, fragment in enumerate(fragments):
            # TODO: figure out if we can get this metadata from LanceDB fragments cheaply
size_bytes = None
stats = None

            # NOTE: `fragment.count_rows()` should result in 1 IO call for the data file
            # (1 fragment = 1 data file) and 1 more IO call for the deletion file (if present).
            # This could be expensive to perform serially if there are thousands of files.
            # Given that num_rows isn't leveraged for much at the moment, and that without
            # statistics we will probably end up materializing the data anyway for any
            # operation, we leave this as None.
num_rows = None

yield ScanTask.python_factory_func_scan_task(
module=_lancedb_table_factory_function.__module__,
func_name=_lancedb_table_factory_function.__name__,
func_args=(fragment, required_columns),
schema=self.schema()._schema,
num_rows=num_rows,
size_bytes=size_bytes,
pushdowns=pushdowns,
stats=stats,
)
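A note on the pushdown path above: when a query projects a subset of columns, those names should arrive in `Pushdowns.columns` and get forwarded to `fragment.to_batches(columns=...)`, so only the selected columns are read from each Lance fragment. A sketch, with a hypothetical dataset path and column names:

```python
import daft

# Hypothetical dataset and column names; the projection below should be pushed
# down into LanceDBScanOperator.to_scan_tasks as Pushdowns.columns.
df = daft.read_lance("s3://my-bucket/my_table.lance").select("id", "vector")
df.collect()
```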
64 changes: 64 additions & 0 deletions daft/io/object_store_options.py
@@ -0,0 +1,64 @@
from __future__ import annotations

from urllib.parse import urlparse

from daft.io import AzureConfig, GCSConfig, IOConfig, S3Config


def io_config_to_storage_options(io_config: IOConfig, table_uri: str) -> dict[str, str] | None:
"""
Converts the Daft IOConfig to a storage options dict that the object_store crate
understands. The object_store crate is used by many Rust-backed Python libraries such as
delta-rs and lance.
This function takes as input the table_uri, which it uses to determine the backend to be used.
"""
scheme = urlparse(table_uri).scheme
if scheme == "s3" or scheme == "s3a":
return _s3_config_to_storage_options(io_config.s3)
elif scheme == "gcs" or scheme == "gs":
return _gcs_config_to_storage_options(io_config.gcs)
elif scheme == "az" or scheme == "abfs":
return _azure_config_to_storage_options(io_config.azure)
else:
return None


def _s3_config_to_storage_options(s3_config: S3Config) -> dict[str, str]:
storage_options: dict[str, str] = {}
if s3_config.region_name is not None:
storage_options["region"] = s3_config.region_name
if s3_config.endpoint_url is not None:
storage_options["endpoint_url"] = s3_config.endpoint_url
if s3_config.key_id is not None:
storage_options["access_key_id"] = s3_config.key_id
if s3_config.session_token is not None:
storage_options["session_token"] = s3_config.session_token
if s3_config.access_key is not None:
storage_options["secret_access_key"] = s3_config.access_key
if s3_config.use_ssl is not None:
storage_options["allow_http"] = "false" if s3_config.use_ssl else "true"
if s3_config.verify_ssl is not None:
storage_options["allow_invalid_certificates"] = "false" if s3_config.verify_ssl else "true"
if s3_config.connect_timeout_ms is not None:
storage_options["connect_timeout"] = str(s3_config.connect_timeout_ms) + "ms"
if s3_config.anonymous:
storage_options["skip_signature"] = "true"
return storage_options


def _azure_config_to_storage_options(azure_config: AzureConfig) -> dict[str, str]:
storage_options = {}
if azure_config.storage_account is not None:
storage_options["account_name"] = azure_config.storage_account
if azure_config.access_key is not None:
storage_options["access_key"] = azure_config.access_key
if azure_config.endpoint_url is not None:
storage_options["endpoint"] = azure_config.endpoint_url
if azure_config.use_ssl is not None:
storage_options["allow_http"] = "false" if azure_config.use_ssl else "true"
return storage_options


def _gcs_config_to_storage_options(_: GCSConfig) -> dict[str, str]:
return {}
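To illustrate the mapping implemented above, here is a sketch of converting an S3-backed IOConfig into object_store options. The credentials are placeholders, and this assumes the usual keyword constructors of `IOConfig` and `S3Config`.

```python
from daft.io import IOConfig, S3Config
from daft.io.object_store_options import io_config_to_storage_options

# Placeholder region and credentials, for illustration only.
io_config = IOConfig(
    s3=S3Config(region_name="us-west-2", key_id="AKIA...", access_key="SECRET")
)

# For an s3:// URI, the result should include entries such as
# {"region": "us-west-2", "access_key_id": "AKIA...", "secret_access_key": "SECRET"};
# the exact set depends on S3Config defaults (e.g. use_ssl, verify_ssl).
options = io_config_to_storage_options(io_config, "s3://my-bucket/my_table.lance")
```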
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ deltalake = ["deltalake"]
gcp = []
hudi = ["pyarrow >= 8.0.0"]
iceberg = ["pyiceberg >= 0.4.0", "packaging"]
lance = ["lancedb"]
numpy = ["numpy"]
pandas = ["pandas"]
ray = [
2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -36,6 +36,8 @@ pyarrow==12
ray[data, client]==2.7.1; python_version < '3.8'
ray[data, client]==2.10.0; python_version >= '3.8'

# Lance
lancedb>=0.6.10; python_version >= '3.8'

#Iceberg
pyiceberg==0.6.0; python_version >= '3.8'