Module library.sources
Expand source code
import os
from osgeo import gdal, ogr
from .utils import parse_engine
def format_field_names(dataset: gdal.Dataset, fields: list):
    """
    Normalize or replace the field (column) names of a dataset's first layer.

    dataset: an open gdal.Dataset (usually from a local file / s3 url);
        must not be None
    fields: a list of predefined replacement field names; if empty, the
        existing names are lower-cased with spaces replaced by underscores

    Returns the same dataset, with the layer definition's field names
    updated in place.
    """
    assert dataset, "dataset: gdal.Dataset shouldn't be None"
    layer = dataset.GetLayer(0)
    layer_defn = layer.GetLayerDefn()
    field_count = layer_defn.GetFieldCount()
    if not fields:
        # No replacement names given: normalize the existing names.
        for i in range(field_count):
            field_defn = layer_defn.GetFieldDefn(i)
            field_defn.SetName(field_defn.GetName().replace(" ", "_").lower())
    else:
        # Guard against more names than actual fields: past the end,
        # GetFieldDefn(i) returns None and SetName would raise an opaque
        # AttributeError instead of a clear message.
        assert len(fields) <= field_count, (
            f"got {len(fields)} field names but layer only has {field_count} fields"
        )
        for i, name in enumerate(fields):
            layer_defn.GetFieldDefn(i).SetName(name)
    return dataset
def get_allowed_drivers(url: str) -> list:
    """
    Return the list of GDAL driver names allowed for gdal.OpenEx,
    given the file type of [url].

    For .csv files, JSON-capable drivers are excluded so GDAL does not
    misidentify the file.
    """
    allowed_drivers = [
        gdal.GetDriver(i).GetDescription() for i in range(gdal.GetDriverCount())
    ]
    _, extension = os.path.splitext(os.path.basename(url))
    # Compare case-insensitively so e.g. "FILE.CSV" is handled like "file.csv".
    if extension.lower() == ".csv":
        allowed_drivers = [driver for driver in allowed_drivers if "JSON" not in driver]
    return allowed_drivers
def postgres_source(url: str) -> gdal.Dataset:
    """
    Open a PostgreSQL database as a GDAL vector dataset.

    url: postgres connection string
    e.g. postgresql://username:password@host:port/database
    """
    connection = parse_engine(url)
    dataset = gdal.OpenEx(connection, gdal.OF_VECTOR)
    return dataset
def generic_source(path: str, options: list = None, fields: list = None) -> gdal.Dataset:
    """
    Open a vector data source with GDAL and normalize its field names.

    path: filepath, http url or s3 file url
    e.g.
    - s3://edm-recipes/2020-01-02/some-file.csv
    - https://some.website.com/file.csv
    - /local/folder/abc.csv
    options: open options forwarded to gdal.OpenEx (default: none)
    fields: replacement field names passed to format_field_names
        (default: normalize the existing names instead)
    """
    # Avoid mutable default arguments ([]): they are shared across calls.
    options = [] if options is None else options
    fields = [] if fields is None else fields
    allowed_drivers = get_allowed_drivers(path)
    dataset = gdal.OpenEx(
        path, gdal.OF_VECTOR, open_options=options, allowed_drivers=allowed_drivers
    )
    # NOTE(review): assert is stripped under `python -O`; kept as-is so callers
    # that catch AssertionError keep working.
    assert dataset, f"{path} is invalid"
    return format_field_names(dataset, fields)
Functions
def format_field_names(dataset: osgeo.gdal.Dataset, fields: list)
-
dataset: the source dataset, usually opened from a local file / s3 url; fields: a list of predefined field names
If we have a list of new field names, then rename fields with "fields" otherwise, change all field names to lower case connected by underscore
Expand source code
def format_field_names(dataset: gdal.Dataset, fields: list): """ dataset: Given source data source, usually a local file / s3 url fields: a list of predefined field names If we have a list of new field names, then rename fields with "fields" otherwise, change all field names to lower case connected by underscore """ assert dataset, "dataset: gdal.Dataset shouldn't be None" layer = dataset.GetLayer(0) layerDefn = layer.GetLayerDefn() if len(fields) == 0: for i in range(layerDefn.GetFieldCount()): fieldDefn = layerDefn.GetFieldDefn(i) fieldName = fieldDefn.GetName() fieldDefn.SetName(fieldName.replace(" ", "_").lower()) else: for i in range(len(fields)): fieldDefn = layerDefn.GetFieldDefn(i) fieldDefn.SetName(fields[i]) return dataset
def generic_source(path: str, options: list = [], fields: list = []) ‑> osgeo.gdal.Dataset
-
path: filepath, http url or s3 file url e.g. - s3://edm-recipes/2020-01-02/some-file.csv - https://some.website.com/file.csv - /local/folder/abc.csv
Expand source code
def generic_source(path: str, options: list = [], fields: list = []) -> gdal.Dataset: """ path: filepath, http url or s3 file url e.g. - s3://edm-recipes/2020-01-02/some-file.csv - https://some.website.com/file.csv - /local/folder/abc.csv """ allowed_drivers = get_allowed_drivers(path) dataset = gdal.OpenEx( path, gdal.OF_VECTOR, open_options=options, allowed_drivers=allowed_drivers ) assert dataset, f"{path} is invalid" dataset = format_field_names(dataset, fields) return dataset
def get_allowed_drivers(url: str) ‑> list
-
Returns allowed drivers for OpenEx given the file type of [url]
Expand source code
def get_allowed_drivers(url: str) -> list: """ Returns allowed drivers for OpenEx given the file type of [url] """ allowed_drivers = [ gdal.GetDriver(i).GetDescription() for i in range(gdal.GetDriverCount()) ] _, extension = os.path.splitext(os.path.basename(url)) if extension == ".csv": allowed_drivers = [driver for driver in allowed_drivers if "JSON" not in driver] return allowed_drivers
def postgres_source(url: str) ‑> osgeo.gdal.Dataset
-
url: postgres connection string e.g. postgresql://username:password@host:port/database
Expand source code
def postgres_source(url: str) -> gdal.Dataset: """ url: postgres connection string e.g. postgresql://username:password@host:port/database """ parsed = parse_engine(url) return gdal.OpenEx(parsed, gdal.OF_VECTOR)