Asif Rahman

ChartQL - A token-efficient SQL extension for generating Plotly charts

Use LLMs to generate Plotly chart specifications directly from SQL queries with embedded chart definitions using a small DSL for chart configuration.

Most LLM-generated chart code is wasteful (“token inefficient”). When you ask a model to create a Plotly chart, it writes mostly boilerplate with a lot of output tokens for what is fundamentally a simple mapping: take these columns, plot them as this chart type, apply these visual options.

ChartQL compresses that entire specification into a few lines. The model only needs to produce a SQL query with a PLOT AS clause, and the execution engine handles the rest. In my experiments, a typical Plotly chart specification runs 400-600 tokens. The equivalent ChartQL is 40-80, which is roughly a 10x reduction.

This is especially important in agentic systems where we want to reduce the number of tool calls. The standard flow would be to write SQL to get the data, then write the chart specification. With ChartQL, a single tool call handles both the query and the chart definition. Fewer tool calls means fewer round trips, lower latency, and less opportunity for the model to introduce errors between steps.

Plotly has an excellent reference doc and a machine-readable specification that describes all the options for laying out and formatting a chart. But with so many options, a full Plotly specification can go wrong in many ways: mismatched list lengths, incorrect nesting of layout dicts, and wrong trace types. ChartQL’s grammar is small enough that a model can reliably produce valid output, and when it doesn’t, the parser gives precise error locations in the chart definition, so it’s easy to fix.

ChartQL allows you to embed Plotly chart definitions directly within SQL queries using the PLOT AS syntax.

You can customize chart appearance with Plotly layout options written in Plotly’s magic underscore notation, which makes it easier to work with nested properties. Supported plot types:

  1. line - Line charts with markers
  2. bar - Bar charts (vertical/horizontal)
  3. scatter - Scatter plots
  4. pie - Pie charts
  5. heatmap - Heat maps
  6. histogram - Histogram plots
  7. box - Box plots
SELECT end_date, n_reports FROM metrics
WHERE measure = 'active_users' AND adjustment = "total"
PLOT AS line(
    x=end_date,
    y=n_reports,
    trace_marker_color='blue',
    trace_marker_size=8,
    layout_title_text='Active Users Over Time',
    layout_xaxis_type='date',
    layout_xaxis_title='Date',
    layout_yaxis_title='Number of Active Users',
    layout_height=400,
    layout_width=700,
    layout_margin_b=50
)

Overall architecture:

import sqlparse
import polars as pl
from sqlparse.tokens import Text
from lark import Lark, Transformer
from lark.exceptions import UnexpectedToken
from typing import Dict, Any, Tuple, List, Union, Callable

class ChartQLException(Exception):
    """Root of the ChartQL error hierarchy, raised for parsing and processing failures."""

class ChartQLParseError(ChartQLException):
    """Parse failure whose message shows the offending line with a caret under the error column.

    Attributes:
        original_error: The Lark error that triggered this exception.
        chart_spec: The chart specification text that was being parsed.
    """

    def __init__(self, original_error: UnexpectedToken, chart_spec: str):
        self.original_error = original_error
        self.chart_spec = chart_spec
        super().__init__(self._format_error())

    def _format_error(self) -> str:
        """Return the Lark message, followed by the source line and a caret when the line exists."""
        # Lark reports 1-based positions; shift both to 0-based offsets.
        row = self.original_error.line - 1
        col = self.original_error.column - 1
        report = str(self.original_error)

        spec_lines = self.chart_spec.splitlines()
        if row < 0 or row >= len(spec_lines):
            # Position falls outside the spec: nothing to point at.
            return report

        caret = " " * col + "^"
        return report + f"\n{spec_lines[row]}\n{caret}"


def _is_plot_as_separator(token) -> bool:
    """Return True when *token* may sit between the PLOT and AS keywords."""
    ttype = token.ttype
    # Containment on a sqlparse token type also matches its subtypes, so a
    # newline (Text.Whitespace.Newline) counts as whitespace here. Tokens with
    # ttype None are still accepted, as before.
    return ttype is None or ttype in Text.Whitespace


def _find_plot_as_span(tokens: List) -> Tuple[int, int]:
    """Locate the PLOT AS clause.

    Returns:
        ``(plot_index, index_after_as)`` for the first match, or ``(-1, -1)``
        when the clause is absent.
    """
    total = len(tokens)
    for i, token in enumerate(tokens):
        if token.value.upper() != "PLOT":
            continue
        j = i + 1
        # PLOT and AS may be separated by any run of whitespace/newline tokens.
        while j < total and _is_plot_as_separator(tokens[j]):
            j += 1
        has_separator = j > i + 1
        if has_separator and j < total and tokens[j].value.upper() == "AS":
            return i, j + 1
    return -1, -1


def find_plot_as_position(tokens: List) -> int:
    """Find the position of the PLOT AS clause in a token list.

    Args:
        tokens: Flattened sqlparse tokens (objects exposing ``value`` and ``ttype``).

    Returns:
        Index of the PLOT token of the first PLOT AS clause, or -1 if not found.
        PLOT and AS must be separated by at least one whitespace token; several
        whitespace tokens or a line break are also accepted.
    """
    return _find_plot_as_span(tokens)[0]


def split_chartql_query(query: str) -> Tuple[str, str]:
    """Split a ChartQL query into its SQL part and its chart specification part.

    Args:
        query: Full ChartQL text, i.e. SQL followed by ``PLOT AS <chart spec>``.

    Returns:
        ``(sql_query, chart_spec)``, both stripped of surrounding whitespace.

    Raises:
        ChartQLException: If the query has no PLOT AS clause. This includes
            empty or whitespace-only input, for which sqlparse yields no statement.
    """
    statements = sqlparse.parse(query)
    # Only the first statement is inspected; an empty parse result means there
    # is nothing to search.
    tokens = list(statements[0].flatten()) if statements else []

    plot_start, spec_start = _find_plot_as_span(tokens)
    if plot_start == -1:
        raise ChartQLException("No PLOT AS clause found")

    sql_query = "".join(token.value for token in tokens[:plot_start]).strip()
    chart_spec = "".join(token.value for token in tokens[spec_start:]).strip()

    return sql_query, chart_spec


def has_chartql_plot(query: str) -> bool:
    """Check whether *query* contains a PLOT AS clause.

    Only the first SQL statement in *query* is inspected. Empty or
    whitespace-only input, for which sqlparse produces no statement, returns
    False instead of raising IndexError.
    """
    statements = sqlparse.parse(query)
    if not statements:
        return False
    tokens = list(statements[0].flatten())
    return find_plot_as_position(tokens) != -1


# Lark grammar for the chart specification that follows PLOT AS:
#   chart_type(key=value, ...)
# Values are strings, numbers, booleans, lists, or bare column names.
#
# The boolean rule is declared with a leading "!" so Lark keeps its keyword
# token in the parse tree. Anonymous string literals are otherwise filtered
# out, which would hand the transformer an empty item list and make every
# boolean evaluate to False.
chartql_grammar = """
    start: chart_type "(" parameter_list? ")"
    chart_type: CHART_TYPE
    parameter_list: parameter ("," parameter)*
    parameter: key "=" value
    key: CNAME
    value: string | number | boolean | list | column_ref
    list: "[" value ("," value)* "]"
    string: SINGLE_STRING | DOUBLE_STRING
    number: SIGNED_NUMBER
    !boolean: "True" | "False" | "true" | "false"
    column_ref: CNAME
    
    CHART_TYPE: "line" | "bar" | "scatter" | "pie" | "heatmap" | "histogram" | "box"
    SINGLE_STRING: /'[^']*'/
    DOUBLE_STRING: /"[^"]*"/
    COMMENT: /--[^\\n\\r]*/
    
    %import common.SIGNED_NUMBER
    %import common.CNAME
    %import common.WS
    %ignore WS
    %ignore COMMENT
"""


class ChartQLTransformer(Transformer):
    """Turn the ChartQL parse tree into plain Python values.

    Each method is named after a grammar rule and receives that rule's already
    transformed children as ``items``. The final result (from ``start``) is
    ``{"chart_type": str, "parameters": dict}``.
    """

    def start(self, items):
        """Combine the chart type with its parameters, defaulting to ``{}`` when none were given."""
        chart_type = items[0]
        params = items[1] if len(items) > 1 else {}
        return {"chart_type": chart_type, "parameters": params}

    def chart_type(self, items):
        return str(items[0])

    def parameter_list(self, items):
        # items is a sequence of (key, value) pairs; a repeated key keeps its last value.
        return dict(items)

    def parameter(self, items):
        return (items[0], items[1])

    def key(self, items):
        return str(items[0])

    def list(self, items):
        # The method name must match the grammar rule; inside the body ``list``
        # still resolves to the builtin.
        return list(items)

    def string(self, items):
        # Drop the surrounding quote characters.
        return str(items[0])[1:-1]

    def number(self, items):
        """Convert a numeric literal to ``int`` when possible, otherwise ``float``.

        Trying ``int`` first also handles exponent literals without a decimal
        point (e.g. ``1e5``), which SIGNED_NUMBER accepts but ``int()`` rejects.
        """
        text = str(items[0])
        try:
            return int(text)
        except ValueError:
            return float(text)

    def boolean(self, items):
        """Return True for ``true``/``True``, else False.

        Lark only passes the keyword token through when the grammar rule keeps
        its tokens (for example a rule declared as ``!boolean``). If the token
        is filtered out, ``items`` is empty and the result is False.
        """
        if not items:
            return False
        return str(items[0]).lower() == "true"

    def column_ref(self, items):
        # Bare identifiers come through as plain strings, the same type as quoted strings.
        return str(items[0])

    def value(self, items):
        return items[0]


def parse_chartql(chart_spec: str) -> Dict[str, Any]:
    """Parse a chart specification into a structured dictionary.

    Args:
        chart_spec: Text following PLOT AS, e.g. ``line(x=a, y=b)``.

    Returns:
        ``{"chart_type": str, "parameters": dict}`` as built by ChartQLTransformer.

    Raises:
        ChartQLParseError: If the specification is rejected by either the lexer
            (an unexpected character) or the parser (an unexpected token).
    """
    # UnexpectedInput is the shared base class of Lark's lexer and parser
    # errors. Catching only UnexpectedToken let lexer failures escape as raw
    # Lark exceptions.
    from lark.exceptions import UnexpectedInput

    parser = Lark(chartql_grammar, parser="lalr", transformer=ChartQLTransformer())
    try:
        return parser.parse(chart_spec)
    except UnexpectedInput as e:
        # "from None" hides the Lark traceback; the formatted message carries the context.
        raise ChartQLParseError(e, chart_spec) from None


# Tabular data accepted by ChartSpecGenerator and DataMapper: either a list of
# row dicts (column name -> value) or a Polars DataFrame.
DataInput = Union[List[Dict[str, Any]], pl.DataFrame]


class ChartSpecGenerator:
    """Build a Plotly figure dict (``data``/``layout``/``config``) from a parsed chart and its data.

    Parameters are routed by prefix: ``layout_*`` to the layout, ``config_*`` to
    the config, ``trace_*`` to every trace. Un-prefixed data fields (``x``,
    ``y``, ...) name the columns to plot.
    """

    # Plotly property names that contain an underscore themselves. They must
    # stay whole when an underscore path is split into nesting levels.
    _UNDERSCORE_PROPERTIES = ("paper_bgcolor", "plot_bgcolor", "error_x", "error_y", "error_z")

    def __init__(self):
        # Base trace attributes for each ChartQL chart type.
        self.chart_defaults = {
            "line": {"type": "scatter", "mode": "lines+markers"},
            "bar": {"type": "bar"},
            "scatter": {"type": "scatter", "mode": "markers"},
            "pie": {"type": "pie"},
            "heatmap": {"type": "heatmap"},
            "histogram": {"type": "histogram"},
            "box": {"type": "box"},
        }

    def generate_spec(self, parsed_chart: Dict[str, Any], data: "DataInput") -> Dict[str, Any]:
        """Return ``{"data": traces, "layout": layout, "config": config}`` for *parsed_chart*.

        Args:
            parsed_chart: Dict with ``"chart_type"`` and ``"parameters"`` keys.
            data: Rows as a list of dicts, or a Polars DataFrame.
        """
        chart_type = parsed_chart["chart_type"]
        parameters = parsed_chart["parameters"]

        layout_params = self._extract_prefixed_params(parameters, "layout_")
        config_params = self._extract_prefixed_params(parameters, "config_")
        trace_params = self._extract_prefixed_params(parameters, "trace_")
        data_params = self._extract_data_params(parameters)

        traces = self._build_traces(chart_type, data_params, trace_params, data)
        layout = self._build_layout(layout_params)
        config = self._build_config(config_params)

        return {"data": traces, "layout": layout, "config": config}

    def _extract_prefixed_params(self, parameters: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        """Return the parameters whose key starts with *prefix*, with the prefix removed."""
        return {
            key[len(prefix):]: value
            for key, value in parameters.items()
            if key.startswith(prefix)
        }

    def _extract_data_params(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Return the parameters that name data columns (``x``, ``y``, ``z``, ...)."""
        data_fields = {"x", "y", "z", "color", "values", "labels", "text", "size"}
        return {key: value for key, value in parameters.items() if key in data_fields}

    def _build_traces(
        self, chart_type: str, data_params: Dict[str, Any], trace_params: Dict[str, Any], data: "DataInput"
    ) -> List[Dict[str, Any]]:
        """Create one trace per column listed in the chart type's primary data field.

        A missing primary field yields no traces.
        """
        base_trace = self.chart_defaults[chart_type].copy()

        primary_data_col = self._get_primary_data_column(chart_type)
        data_columns = data_params.get(primary_data_col, [])

        if not isinstance(data_columns, list):
            data_columns = [data_columns]

        traces = []
        for i, col in enumerate(data_columns):
            trace = base_trace.copy()

            # Assign the default name first so that an explicit trace_name
            # parameter, applied below, overrides it rather than being
            # overwritten.
            trace["name"] = col if len(data_columns) > 1 else chart_type

            self._apply_data_mappings(trace, chart_type, data_params, data, i)
            self._apply_trace_params(trace, trace_params, i)

            traces.append(trace)

        return traces

    def _apply_trace_params(self, trace: Dict[str, Any], trace_params: Dict[str, Any], trace_index: int) -> None:
        """Apply ``trace_*`` parameters to *trace*.

        A list value is treated as per-trace: element ``trace_index`` is used,
        and the parameter is skipped if the list is too short. Any other value
        applies to every trace.
        """
        for key, value in trace_params.items():
            if isinstance(value, list) and len(value) > trace_index:
                self._set_nested_dict(trace, key, value[trace_index])
            elif not isinstance(value, list):
                self._set_nested_dict(trace, key, value)

    def _apply_data_mappings(
        self, trace: Dict[str, Any], chart_type: str, data_params: Dict[str, Any], data: "DataInput", trace_index: int
    ) -> None:
        """Copy column data into *trace* for each data field the chart type supports.

        A list of columns is honoured only for the primary field (one column
        per trace). Lists given for other fields are ignored.
        """
        data_field_mappings = {
            "line": ["x", "y"],
            "scatter": ["x", "y", "size", "color", "text"],
            "bar": ["x", "y"],
            "pie": ["values", "labels"],
            "heatmap": ["x", "y", "z"],
            "histogram": ["x"],
            "box": ["x", "y"],
        }

        for field in data_field_mappings.get(chart_type, []):
            if field in data_params:
                col_ref = data_params[field]
                if isinstance(col_ref, list):
                    primary_col = self._get_primary_data_column(chart_type)
                    if field == primary_col and len(col_ref) > trace_index:
                        trace[field] = self._get_column_data(data, col_ref[trace_index])
                else:
                    trace[field] = self._get_column_data(data, col_ref)

    def _get_primary_data_column(self, chart_type: str) -> str:
        """Return the data field that determines the number of traces (``"y"`` by default)."""
        primary_cols = {"line": "y", "scatter": "y", "bar": "y", "pie": "values", "histogram": "x", "box": "y"}
        return primary_cols.get(chart_type, "y")

    def _get_column_data(self, data: "DataInput", column: str) -> List[Any]:
        """Extract column data as a list from either a list of dicts or a Polars DataFrame.

        Raises:
            ChartQLException: If *data* is neither a list nor a Polars DataFrame.
        """
        if isinstance(data, list):
            # Rows lacking the key contribute None.
            return [row.get(column) for row in data]
        if isinstance(data, pl.DataFrame):
            return data[column].to_list()
        raise ChartQLException(f"Unsupported data type: {type(data)}")

    def _get_columns(self, data: "DataInput") -> List[str]:
        """Get column names from data; a list of dicts takes them from its first row."""
        if isinstance(data, list):
            return list(data[0].keys()) if data else []
        if isinstance(data, pl.DataFrame):
            return data.columns
        return []

    def _build_layout(self, layout_params: Dict[str, Any]) -> Dict[str, Any]:
        """Expand underscore-separated layout keys into a nested layout dict."""
        layout = {}
        for key, value in layout_params.items():
            self._set_nested_dict(layout, key, value)
        return layout

    def _build_config(self, config_params: Dict[str, Any]) -> Dict[str, Any]:
        """Return a shallow copy of the config parameters (config keys are not nested)."""
        return config_params.copy()

    def _split_property_path(self, key: str) -> List[str]:
        """Split an underscore path into nesting levels, keeping known underscore properties whole."""
        parts = key.split("_")
        path = []
        i = 0
        while i < len(parts):
            pair = "_".join(parts[i : i + 2])
            if pair in self._UNDERSCORE_PROPERTIES:
                # e.g. "paper" + "bgcolor" is the single property "paper_bgcolor".
                path.append(pair)
                i += 2
            else:
                path.append(parts[i])
                i += 1
        return path

    def _set_nested_dict(self, d: Dict[str, Any], key: str, value: Any) -> None:
        """Set *value* in *d* at the nested location described by underscore-separated *key*.

        ``"xaxis_title"`` becomes ``d["xaxis"]["title"]``. An empty key is ignored.

        Raises:
            ChartQLException: If an intermediate level already holds a non-dict
                value (for example ``title='X'`` followed by ``title_font_size=12``).
        """
        if not key:
            return
        path = self._split_property_path(key)
        current = d
        for part in path[:-1]:
            child = current.setdefault(part, {})
            if not isinstance(child, dict):
                raise ChartQLException(
                    f"Cannot set '{key}': '{part}' is already set to a non-dict value"
                )
            current = child
        current[path[-1]] = value


class DataMapper:
    """Check and resolve the column references of chart parameters against a dataset.

    Attributes:
        data: The dataset (list of row dicts or Polars DataFrame).
        columns: Set of column names found in ``data``.
    """

    def __init__(self, data: "DataInput"):
        self.data = data
        self.columns = set(self._get_columns(data))

    def resolve_column_references(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *parameters* with data-field values passed through ``_resolve_value``."""
        return {
            key: self._resolve_value(value) if self._is_data_field(key) else value
            for key, value in parameters.items()
        }

    def _is_data_field(self, key: str) -> bool:
        """Return True if *key* is a parameter whose value names data columns."""
        data_fields = {"x", "y", "z", "color", "values", "labels", "text", "size"}
        return key in data_fields

    def _resolve_value(self, value: Any) -> Any:
        """Return *value* unchanged, copying lists element by element."""
        if isinstance(value, list):
            return [self._resolve_value(v) for v in value]
        return value

    def validate_columns(self, parameters: Dict[str, Any]) -> None:
        """Ensure every column named in a data field exists in the dataset.

        Every string given for a data field is treated as a column name.
        Validation is skipped for an empty list of rows, because it carries no
        column information to check against.

        Raises:
            ChartQLException: Listing the referenced columns that are absent.
        """
        if isinstance(self.data, list) and not self.data:
            return

        missing_columns = []
        for key, value in parameters.items():
            if not self._is_data_field(key):
                continue
            for col in self._extract_column_names(value):
                # Report each missing column once, in first-seen order.
                if col not in self.columns and col not in missing_columns:
                    missing_columns.append(col)

        if missing_columns:
            raise ChartQLException(f"Columns not found in data: {missing_columns}")

    def _extract_column_names(self, value: Any) -> List[str]:
        """Return every column name referenced by *value* (a string or a nested list of strings).

        Names are returned whether or not they exist in the dataset. Filtering
        out unknown names here would leave validate_columns nothing to report.
        """
        if isinstance(value, str):
            return [value]
        if isinstance(value, list):
            names = []
            for v in value:
                names.extend(self._extract_column_names(v))
            return names
        return []

    def _get_columns(self, data: "DataInput") -> List[str]:
        """Get column names from data; a list of dicts takes them from its first row."""
        if isinstance(data, list):
            return list(data[0].keys()) if data else []
        if isinstance(data, pl.DataFrame):
            return data.columns
        return []


def process_chartql_query(query: str, sql_executor: Callable) -> Dict[str, Any]:
    """Process a complete ChartQL query and return a Plotly specification.

    Args:
        query: SQL text followed by a ``PLOT AS <chart spec>`` clause.
        sql_executor: Callable that receives the SQL part and returns the
            result as a list of row dicts or a Polars DataFrame.

    Returns:
        Dict with ``"data"``, ``"layout"`` and ``"config"`` keys.

    Raises:
        ChartQLException: If the PLOT AS clause is missing, the chart
            specification cannot be parsed (ChartQLParseError), or column
            validation fails.
    """
    sql_query, chart_spec = split_chartql_query(query)

    # Parse the chart specification before running the SQL, so a malformed
    # spec fails fast without paying for (or triggering) a query execution.
    parsed_chart = parse_chartql(chart_spec)
    parameters = parsed_chart["parameters"]

    data = sql_executor(sql_query)

    data_mapper = DataMapper(data)
    data_mapper.validate_columns(parameters)
    resolved_params = data_mapper.resolve_column_references(parameters)

    spec_generator = ChartSpecGenerator()
    return spec_generator.generate_spec(
        {"chart_type": parsed_chart["chart_type"], "parameters": resolved_params}, data
    )

#LLM