Shizu0n's picture
feat: double-quoted comparison literals for schema validation
35d7c39
import html
import re
import unicodedata
import sqlparse
# Keywords accepted as the first word of a SQL statement; consulted by
# is_sql_like() and validate_sql().
SQL_STARTERS = {"SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP"}
# Question returned when a user message yields no usable text to forward.
LANGUAGE_FALLBACK_QUESTION = "write a SELECT query for this table."
# Accent-free, lower-case non-English words. normalize_sql_question_to_english()
# looks for any of them as a whole word in the normalized message and, when one
# is present, returns the lexically translated question instead of the raw text.
FOREIGN_SQL_MARKERS = {
"alumnos", "alunos", "barato", "caro", "clientes", "combien", "conte",
"contar", "cuantos", "departamento", "empleados", "estoque",
"funcionarios", "liste", "listar", "maior", "mas", "media", "menor",
"mostre", "mostrar", "moyenne", "ordenado", "ordene", "par", "por",
"preco", "precio", "produit", "produits", "producto", "productos",
"produto", "produtos", "qual", "quantos", "salaire", "salario", "soma",
"vendas",
}
# Canonical table name -> accepted spellings (singular/plural, several
# languages, accents already stripped). Used by _table_subject() and
# _lexical_translate_sql_question() to map a word back to its canonical key.
TABLE_ALIASES = {
"employees": {"employee", "employees", "funcionario", "funcionarios", "empleado", "empleados"},
"orders": {"order", "orders", "pedido", "pedidos"},
"students": {"student", "students", "aluno", "alunos", "alumno", "alumnos", "etudiant", "etudiants"},
"products": {"product", "products", "produto", "produtos", "producto", "productos", "produit", "produits"},
"sales": {"sale", "sales", "venda", "vendas", "venta", "ventas"},
"customers": {"customer", "customers", "cliente", "clientes"},
}
# Canonical column name -> accepted spellings (accents already stripped).
# _canonical_column() and the message-matching helpers compare normalized
# text against these sets, so every entry must be lower-case and accent-free.
COLUMN_ALIASES = {
"amount": {"amount", "valor", "importe", "montant"},
"category": {"category", "categoria", "categorie"},
"course": {"course", "curso"},
"customer_id": {"customer_id", "cliente_id", "id_cliente"},
"date": {"date", "data", "fecha"},
"department": {"department", "departamento", "departement"},
"grade": {"grade", "nota", "calificacion"},
"name": {"name", "nome", "nombre", "nom"},
"price": {"price", "preco", "precio", "prix"},
"product": {"product", "produto", "producto", "produit"},
"product_id": {"product_id", "produto_id", "producto_id", "id_produto"},
"quantity": {"quantity", "quantidade", "cantidad", "quantite"},
"salary": {"salary", "salario", "salaire", "sueldo"},
"stock": {"stock", "estoque", "inventario"},
"total": {"total", "soma", "sum"},
"weight": {"weight", "peso", "poids"},
"year": {"year", "ano", "anio", "annee"},
}
# (regex, replacement) pairs applied in this order with re.sub() by
# _lexical_translate_sql_question() before the per-word lookup, so that
# multi-word expressions are rewritten as a unit.
PHRASE_TRANSLATIONS = (
(r"\bmaior que\b", "greater than"),
(r"\bmenor que\b", "less than"),
(r"\bmaior ou igual a\b", "greater than or equal to"),
(r"\bmenor ou igual a\b", "less than or equal to"),
(r"\bmais caro\b", "most expensive"),
(r"\bmas caro\b", "most expensive"),
(r"\bplus cher\b", "most expensive"),
(r"\bpreco mais alto\b", "highest price"),
(r"\bprecio mas alto\b", "highest price"),
(r"\bmaior preco\b", "highest price"),
(r"\bmais barato\b", "cheapest"),
(r"\bmas barato\b", "cheapest"),
(r"\bplus bas prix\b", "cheapest"),
(r"\bmenor preco\b", "lowest price"),
(r"\bmaior para menor\b", "descending"),
(r"\bmenor para maior\b", "ascending"),
)
# Single-word translations looked up by _lexical_translate_sql_question()
# for each normalized token. A value may contain several words
# (e.g. "how many"); the caller splits it before appending.
TOKEN_TRANSLATIONS = {
"a": "the",
"agrupe": "group",
"alunos": "students",
"barato": "cheap",
"caro": "expensive",
"com": "with",
"conte": "count",
"contar": "count",
"cuantos": "how many",
"da": "of",
"das": "of",
"de": "of",
"do": "of",
"dos": "of",
"e": "and",
"em": "in",
"filtre": "filter",
"filtrar": "filter",
"funcionarios": "employees",
"liste": "list",
"listar": "list",
"maior": "highest",
"mas": "more",
"media": "average",
"menor": "lowest",
"mostre": "show",
"mostrar": "show",
"moyenne": "average",
"o": "the",
"os": "the",
"ordene": "order",
"ordenado": "ordered",
"par": "by",
"para": "for",
"por": "by",
"produto": "product",
"produtos": "products",
"peso": "weight",
"qual": "what",
"quantos": "how many",
"soma": "sum",
"some": "sum",
"todos": "all",
}
def content_to_text(value):
    """Flatten a message payload into a plain string.

    Strings are returned unchanged and ``None`` becomes ``""``. A dict is
    reduced to its first present key among ``text``, ``content`` and
    ``value``; when none is present its values are flattened and joined with
    spaces. Lists and tuples are flattened item by item and joined with
    newlines. Anything else goes through ``str()``.
    """
    if isinstance(value, str):
        return value
    if value is None:
        return ""
    if isinstance(value, dict):
        preferred = next(
            (key for key in ("text", "content", "value") if key in value),
            None,
        )
        if preferred is not None:
            return content_to_text(value[preferred])
        return " ".join(map(content_to_text, value.values()))
    if isinstance(value, (list, tuple)):
        return "\n".join(map(content_to_text, value))
    return str(value)
def normalize_text(value):
    """Return ``value`` as lower-case, accent-free, single-spaced text.

    The payload is flattened with ``content_to_text``, lower-cased,
    NFKD-decomposed, stripped of combining marks, and has every whitespace
    run collapsed to one space with the ends trimmed.
    """
    lowered = content_to_text(value).lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    base_characters = [
        character
        for character in decomposed
        if unicodedata.combining(character) == 0
    ]
    return re.sub(r"\s+", " ", "".join(base_characters)).strip()
def clean_generation(text):
    """Strip chat/markdown wrapping from a model generation.

    Removes a surrounding code fence (opening line exactly ``` or ```sql,
    case-insensitive; closing line ```), cuts the text at the first
    occurrence of each chat stop marker, and drops a leading ``SQL:`` label.
    """
    result = content_to_text(text).strip()
    if result.startswith("```"):
        body = result.splitlines()
        if body and body[0].strip().lower() in {"```", "```sql"}:
            del body[0]
        if body and body[-1].strip() == "```":
            body.pop()
        result = "\n".join(body).strip()
    for stop_marker in ("<|end|>", "<|user|>", "<|assistant|>", "</s>"):
        head, found, _tail = result.partition(stop_marker)
        if found:
            # Everything from the marker onwards is discarded.
            result = head.strip()
    if result.upper().startswith("SQL:"):
        result = result[len("SQL:"):].strip()
    return result
def extract_sql_candidate(text):
    """Return the cleaned generation starting at its first SQL keyword.

    When no statement keyword is found, the cleaned text is returned as is.
    """
    candidate = clean_generation(text)
    keyword = re.search(
        r"\b(SELECT|WITH|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b",
        candidate,
        flags=re.IGNORECASE,
    )
    return candidate[keyword.start():].strip() if keyword else candidate
def is_sql_like(text):
    """Tell whether ``text`` begins with one of the ``SQL_STARTERS`` keywords.

    Empty input, or input whose first characters are not letters, is not
    SQL-like.
    """
    stripped = (text or "").strip()
    leading_word = re.match(r"^\s*([A-Za-z]+)", stripped) if stripped else None
    if leading_word is None:
        return False
    return leading_word.group(1).upper() in SQL_STARTERS
def is_sql_intent(message, schema=""):
    """Decide whether a chat message is asking for a SQL query.

    Small talk (exact phrases, plus a few substrings) is rejected first.
    Otherwise the message qualifies when it contains any SQL-related term as
    a whole word, or when a schema is supplied and the message itself starts
    like a SQL statement.
    """
    normalized = normalize_text(message)
    if not normalized:
        return False

    smalltalk_patterns = {
        "oi", "ola", "olá", "hi", "hello", "hey", "obrigado", "obrigada", "thanks",
        "thank you", "como voce esta", "como você esta", "qual seu nome", "me conte uma piada",
        "vamos conversar", "como voce funciona", "como funciona", "o que voce faz", "o que faz",
    }
    # Normalise the patterns as well so accented spellings compare equal.
    if normalized in {normalize_text(item) for item in smalltalk_patterns}:
        return False
    for fragment in ("como voce esta", "qual seu nome", "conte uma piada"):
        if fragment in normalized:
            return False

    sql_terms = {
        "all", "average", "count", "columns", "database", "find", "get", "group by",
        "highest", "join", "least", "list", "lowest", "max", "maximum", "min", "minimum",
        "most", "most expensive", "order by", "query", "rows", "select", "show", "sum",
        "where",
        "consulta", "consultar", "contar", "colunas", "linhas", "liste", "listar",
        "maior", "mais caro", "menor", "media", "mostre", "mostrar", "ordene",
        "selecione", "some", "soma", "quantos", "filtre", "filtrar",
    }
    for term in sql_terms:
        if re.search(rf"(?<!\w){re.escape(normalize_text(term))}(?!\w)", normalized):
            return True
    return bool(schema and is_sql_like(normalized))
# Function names that sql_schema_validation_issue() accepts as identifiers
# without requiring them to be schema columns (compared upper-cased).
SQL_SCHEMA_FUNCTION_NAMES = {
"AVG",
"COALESCE",
"COUNT",
"LOWER",
"MAX",
"MIN",
"ROUND",
"SUM",
"UPPER",
}
# Keywords that can directly follow a table name in FROM/JOIN; a word in this
# set is never recorded as a table alias by sql_schema_validation_issue().
SQL_ALIAS_STOPWORDS = {
"FULL",
"GROUP",
"HAVING",
"INNER",
"JOIN",
"LEFT",
"LIMIT",
"ON",
"ORDER",
"RIGHT",
"WHERE",
}
def _without_sql_literals(sql_text):
return re.sub(
r"'(?:''|[^'])*'|\"(?:\"\"|[^\"])*\"|\b\d+(?:\.\d+)?\b",
" ",
sql_text or "",
)
def _without_sql_value_literals(sql_text):
without_quoted_values = re.sub(
r"(=|<>|!=|<=|>=|<|>)\s*\"(?:\"\"|[^\"])*\"",
r"\1 ",
sql_text or "",
)
return re.sub(
r"'(?:''|[^'])*'|\b\d+(?:\.\d+)?\b",
" ",
without_quoted_values,
)
def _identifier_names(sql_text):
    """List identifier tokens in ``sql_text`` together with their neighbours.

    Returns ``(name, previous_value, next_value)`` tuples, one per token
    whose type is a ``Name`` (sub)type or a quoted symbol. ``name`` has
    surrounding double quotes removed; the neighbour values are the text of
    the closest non-whitespace tokens in the same statement, or ``""`` at a
    statement boundary. If sqlparse fails to parse, an empty list is
    returned.
    """
    try:
        statements = [stmt for stmt in sqlparse.parse(sql_text) if str(stmt).strip()]
    except Exception:
        # Best effort: an unparsable statement yields no identifiers.
        return []

    names = []
    for statement in statements:
        # Drop whitespace once so neighbours are plain index lookups instead
        # of a rescan of the token list for every identifier.
        significant = [token for token in statement.flatten() if not token.is_whitespace]
        last_index = len(significant) - 1
        for position, token in enumerate(significant):
            is_name = token.ttype in sqlparse.tokens.Name
            is_quoted_symbol = token.ttype is sqlparse.tokens.Literal.String.Symbol
            if not (is_name or is_quoted_symbol):
                continue
            previous_value = significant[position - 1].value if position > 0 else ""
            next_value = significant[position + 1].value if position < last_index else ""
            names.append((token.value.strip('"'), previous_value, next_value))
    return names
def sql_schema_validation_issue(sql_text, schema):
    """Describe the first mismatch between ``sql_text`` and ``schema``.

    Returns ``""`` when the schema cannot be parsed or nothing is wrong.
    Otherwise returns a message for: no FROM/JOIN table at all, a FROM/JOIN
    target that is neither the schema table nor a CTE name, or an identifier
    that is not a column, table, alias, or whitelisted function.
    """
    table_name, columns = parse_create_table_schema(schema)
    if not (table_name and columns):
        return ""

    expected_table = table_name.lower()
    known_columns = {column_name.lower() for column_name, _column_type in columns}
    literal_free_sql = _without_sql_literals(sql_text)

    cte_names = set()
    for cte in re.finditer(
        r"(?:WITH|,)\s+([A-Za-z_][\w]*)\s+AS\s*\(",
        literal_free_sql,
        flags=re.IGNORECASE,
    ):
        cte_names.add(cte.group(1).lower())

    referenced_tables = [
        reference.lower()
        for reference in re.findall(
            r"\b(?:FROM|JOIN)\s+([A-Za-z_][\w]*)",
            literal_free_sql,
            flags=re.IGNORECASE,
        )
    ]
    if not referenced_tables:
        return f"Model SQL does not reference active table: {table_name}"
    valid_tables = {expected_table, *cte_names}
    for referenced in referenced_tables:
        if referenced not in valid_tables:
            return f"Unknown table for active schema: {referenced}"

    # Words following the table name count as aliases unless they are
    # really the next clause keyword.
    table_alias_pattern = (
        r"\b(?:FROM|JOIN)\s+"
        rf"({re.escape(table_name)})"
        r"\s+(?:AS\s+)?([A-Za-z_][\w]*)"
    )
    table_aliases = {
        alias.lower()
        for _table, alias in re.findall(table_alias_pattern, literal_free_sql, flags=re.IGNORECASE)
        if alias.upper() not in SQL_ALIAS_STOPWORDS
    }

    identifier_sql = _without_sql_value_literals(sql_text)
    output_aliases = {
        (quoted or bare).lower()
        for quoted, bare in re.findall(
            r"\bAS\s+(?:\"([^\"]+)\"|([A-Za-z_][\w]*))",
            identifier_sql,
            flags=re.IGNORECASE,
        )
    }

    accepted_non_columns = {
        expected_table,
        *cte_names,
        *table_aliases,
        *output_aliases,
    }
    for name, _previous_value, next_value in _identifier_names(identifier_sql):
        lowered = name.lower()
        is_qualifier = next_value == "."  # e.g. the "p" in "p.price"
        if (
            is_qualifier
            or lowered in accepted_non_columns
            or name.upper() in SQL_SCHEMA_FUNCTION_NAMES
            or lowered in known_columns
        ):
            continue
        return f"Unknown column for active schema: {name}"
    return ""
def validate_sql(sql_text, schema=""):
    """Return an HTML badge describing whether ``sql_text`` looks valid.

    Checks, in order: empty input, sqlparse failure, no parsed statement,
    unexpected first keyword, a dangling trailing keyword, a bare
    ``NOT <name>`` at the end, a dangling comparison operator, and finally
    schema consistency via ``sql_schema_validation_issue``. Every dynamic
    value inserted into the markup is HTML-escaped.
    """
    syntax_badge = '<span class="validator-badge validator-warn">Check syntax</span>'

    def syntax_warning(detail):
        # ``detail`` must already be safe to embed in HTML.
        return syntax_badge + f'<span class="validator-detail">{detail}</span>'

    sql_text = (sql_text or "").strip()
    if not sql_text:
        return '<span class="validator-badge validator-empty">No SQL yet</span>'

    try:
        statements = [stmt for stmt in sqlparse.parse(sql_text) if str(stmt).strip()]
    except Exception as exc:
        error_type = html.escape(type(exc).__name__)
        return syntax_warning(f"sqlparse error: {error_type}")
    if not statements:
        return syntax_warning("No parsed SQL statement.")

    first_token = statements[0].token_first(skip_cm=True)
    token_value = "UNKNOWN" if first_token is None else first_token.value.strip().upper()
    if token_value not in SQL_STARTERS:
        return syntax_warning(f"First token: {html.escape(token_value)}")

    trailing_keyword = re.search(
        r"\b(AND|BY|FROM|GROUP|HAVING|JOIN|LIMIT|NOT|ON|OR|ORDER|SELECT|WHERE)\s*;?\s*$",
        sql_text,
        flags=re.IGNORECASE,
    )
    if trailing_keyword:
        keyword = html.escape(trailing_keyword.group(1).upper())
        return syntax_warning(f"Incomplete trailing clause: {keyword}")

    bare_negated_predicate = re.search(
        r"\b(AND|HAVING|OR|WHERE)\s+NOT\s+([A-Za-z_][\w.]*)\s*;?\s*$",
        sql_text,
        flags=re.IGNORECASE,
    )
    if bare_negated_predicate:
        operand = html.escape(bare_negated_predicate.group(2))
        return syntax_warning(f"Incomplete negated predicate: NOT {operand}")

    trailing_comparison = re.search(r"(=|<>|!=|<=|>=|<|>)\s*;?\s*$", sql_text)
    if trailing_comparison:
        operator = html.escape(trailing_comparison.group(1))
        return syntax_warning(f"Incomplete comparison operator: {operator}")

    schema_issue = sql_schema_validation_issue(sql_text, schema)
    if schema_issue:
        return (
            '<span class="validator-badge validator-warn">Check schema</span>'
            f'<span class="validator-detail">{html.escape(schema_issue)}</span>'
        )
    return '<span class="validator-badge validator-ok">Valid SQL</span>'
# Verb alternation for create-table intent. create_table_from_message embeds
# the same alternation in one of its patterns, so keep the two in sync.
_CREATE_VERBS = (
    r"create|make|build|generate"
    r"|criar|crie|cria|criando"
    r"|gerar|gere|gera|gerando"
    r"|faz|faca|fa\u00e7a|fazendo"
    r"|monta|montar|monte"
    r"|construa|construir|constroi"
    r"|elabore|elaborar|elabora"
)


def is_create_table_intent(message):
    """Tell whether ``message`` asks to create a table.

    True when the lower-cased message contains both a creation verb from
    ``_CREATE_VERBS`` and one of the nouns ``table``, ``schema`` or
    ``tabela`` as whole words.
    """
    lowered = (message or "").strip().lower()
    if not re.search(rf"\b({_CREATE_VERBS})\b", lowered):
        return False
    return re.search(r"\b(table|schema|tabela)\b", lowered) is not None
def is_rename_intent(message):
    """Tell whether ``message`` contains at least one column rename request."""
    renames = extract_renamed_columns(message)
    return len(renames) > 0
def is_table_edit_intent(message):
    """Tell whether ``message`` asks to change the table's columns.

    True for an "add" verb (unless the next word is an aggregation term such
    as ``up`` or ``sum``), any removal verb, a rename request, the
    ``altere/mude ... ter`` construction, or a generic edit verb combined
    with a column noun, a colon, or the word ``por``.
    """
    text = (message or "").strip().lower()
    edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|exclui|novo|nova|troca|trocar|coloque|colocar|coloca|insira|insere|bota|tira|retire|retira|apaga|apague)\b"
    direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar|coloca|acrescenta|insere|inserir|insira|bota|botar|bote)\b"
    direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b"
    target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
    sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by", "soma", "media", "contagem", "maximo", "minimo"}

    wants_add = False
    add_match = re.search(direct_add_terms, text)
    if add_match:
        # "add up ..." / "add total ..." are query wording, not column edits.
        following_words = text[add_match.end():].split()
        word_after_verb = following_words[0] if following_words else ""
        wants_add = word_after_verb not in sql_aggregation_terms
    if wants_add:
        return True
    if re.search(direct_remove_terms, text):
        return True
    if is_rename_intent(text):
        return True
    if re.search(r"\b(?:altere|alterar|mude|mudar)\b.*\bter\b", text):
        return True
    if not re.search(edit_terms, text):
        return False
    return bool(
        re.search(target_terms, text)
        or ":" in text
        or re.search(r"\bpor\b", text)
    )
def is_unsupported_data_mutation_intent(message):
    """Tell whether ``message`` asks to change rows or drop/truncate a table.

    Create-table requests are never flagged. Otherwise the message is flagged
    for explicit DML/DDL wording, for ``insert`` without a column noun, for a
    mutation verb followed later by a row noun, or for a destructive verb
    immediately followed by an "all"/rows/table word.
    """
    text = normalize_text(message)
    if not text or is_create_table_intent(text):
        return False

    explicit_dml = (
        r"\b(?:insert\s+into|update\b.+\bset\b|delete\s+from|"
        r"drop\s+table|drop\s+(?:the\s+|a\s+|an\s+)?\w+\s+table|truncate\s+table)\b"
    )
    if re.search(explicit_dml, text):
        return True

    mentions_insert = re.search(r"\binsert\b", text)
    mentions_column = re.search(r"\b(column|field|coluna|campo)\b", text)
    if mentions_insert and not mentions_column:
        return True

    row_terms = r"(?:row|rows|record|records|linha|linhas|registro|registros)"
    row_mutation = (
        rf"\b(?:add|create|insert|include|remove|delete|drop|"
        rf"adicionar|adicione|criar|crie|incluir|inclua|remover|remova|deletar|excluir|exclua|apagar|apague)\b"
        rf".*\b{row_terms}\b"
    )
    if re.search(row_mutation, text):
        return True

    destructive_all = (
        r"\b(?:delete|drop|remove|deletar|exclua|excluir|apaga|apagar|apague|remova|remover)\b"
        r"\s+(?:all|every|rows?|records?|table|todos|todas|linhas?|registros?|tabela)\b"
    )
    return re.search(destructive_all, text) is not None
def _table_name_variants(table_name):
    """Return the normalized table name plus one singular/plural variant.

    ``...ies`` becomes ``...y``, a trailing ``s`` is dropped, and any other
    name gets an ``s`` appended. An empty name yields an empty set.
    """
    base = normalize_text(table_name)
    if not base:
        return set()
    if base.endswith("ies") and len(base) > 3:
        alternate = f"{base[:-3]}y"
    elif base.endswith("s") and len(base) > 1:
        alternate = base[:-1]
    else:
        alternate = f"{base}s"
    return {form for form in (base, alternate) if form}
def is_unsupported_data_mutation_for_schema(message, active_schema=""):
    """Extend the generic mutation check with the active table's name.

    Besides everything ``is_unsupported_data_mutation_intent`` flags, this
    also flags a destructive verb aimed directly at the active table
    (singular or plural form, optionally with "the"/"a"/"an" and "table").
    """
    if is_unsupported_data_mutation_intent(message):
        return True
    table_name, _columns = parse_create_table_schema(create_table_from_schema(active_schema))
    if not table_name:
        return False

    # Longest variant first so the alternation prefers "products" to "product".
    longest_first = sorted(_table_name_variants(table_name), key=len, reverse=True)
    table_pattern = "|".join(map(re.escape, longest_first))
    table_mutation = (
        rf"\b(?:delete|drop|remove|deletar|excluir|exclua|remover|remova)\b"
        rf"\s+(?:the\s+|a\s+|an\s+)?(?:table\s+)?(?:{table_pattern})\b(?:\s+table)?"
    )
    return re.search(table_mutation, normalize_text(message)) is not None
def infer_column_type(column_name):
    """Guess a SQL type from a column name.

    Identifier-like and count-like names map to ``INTEGER``, money or
    measurement names to ``NUMERIC``, date-like names to ``DATE``, and
    everything else to ``TEXT``. Matching is case-insensitive and ignores
    surrounding whitespace.
    """
    key = column_name.strip().lower()
    integer_names = {"quantity", "quantidade", "stock", "estoque", "year"}
    numeric_names = {
        "salary", "price", "preco", "amount", "total", "grade", "peso", "weight",
        "idade", "age", "altura", "height", "largura", "width", "comprimento",
        "length", "desconto", "discount",
    }
    date_names = {"date", "created_at", "updated_at"}

    if key == "id" or key.endswith("_id") or key in integer_names:
        inferred = "INTEGER"
    elif key in numeric_names:
        inferred = "NUMERIC"
    elif key in date_names or key.endswith("_date"):
        inferred = "DATE"
    else:
        inferred = "TEXT"
    return inferred
def normalize_identifier(value):
    """Turn free text into a snake_case SQL identifier.

    The text is normalized, every run of non-word characters becomes ``_``,
    and outer underscores are trimmed. A result starting with a digit is
    prefixed with ``col_``; empty input gives ``""``.
    """
    identifier = re.sub(r"\W+", "_", normalize_text(value)).strip("_")
    if identifier and identifier[0].isdigit():
        return f"col_{identifier}"
    return identifier
def parse_column_definition(raw_column):
    """Parse one column description into ``(name, type)``.

    Politeness fillers and column nouns are removed. The last type keyword
    in the text is taken as the declared type and the text before it as the
    name; when nothing precedes the keyword, the whole text is the name and
    the type is inferred from it. Returns ``None`` when no name is left.
    """
    text = re.sub(r"\b(for me|please|por favor)\b", "", raw_column or "", flags=re.IGNORECASE)
    text = text.strip(" .;:")
    if not text:
        return None

    # Keep only the final type keyword of the text.
    last_type = None
    for last_type in re.finditer(
        r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b",
        text,
        flags=re.IGNORECASE,
    ):
        pass

    canonical_types = {
        "INT": "INTEGER",
        "BOOL": "BOOLEAN",
        "DECIMAL": "NUMERIC",
        "FLOAT": "REAL",
        "DOUBLE": "REAL",
    }
    name_part = text
    column_type = None
    if last_type is not None:
        prefix = text[: last_type.start()].strip()
        if prefix.strip():
            declared = last_type.group(1).upper()
            name_part = prefix
            column_type = canonical_types.get(declared, declared)
        # Otherwise the keyword itself is the column name (e.g. "date").

    name_part = re.sub(r"\b(column|field|coluna|campo)\b", "", name_part, flags=re.IGNORECASE)
    column_name = normalize_identifier(name_part)
    if not column_name:
        return None
    return column_name, column_type or infer_column_type(column_name)
def split_column_list(columns_text):
columns_text = re.sub(r"\s+(and|e)\s+", ",", columns_text or "", flags=re.IGNORECASE)
parts = []
type_pattern = r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b"
type_tokens = {
"integer", "int", "numeric", "decimal", "real", "float", "double",
"text", "varchar", "char", "date", "datetime", "timestamp", "boolean", "bool",
}
stopwords = {"to", "from", "into", "as", "for", "o", "a", "os", "de", "do", "da", "dos", "das"}
for part in (item.strip() for item in columns_text.split(",") if item.strip()):
tokens = [token.strip() for token in re.split(r"\s+", part) if token.strip()]
tokens = [token for token in tokens if token.lower() not in stopwords]
if not tokens:
continue
if re.search(type_pattern, part, flags=re.IGNORECASE) and len(tokens) > 2:
index = 0
inferrable_names = {"total", "date", "time", "timestamp", "int", "text", "real", "char"}
while index < len(tokens):
current = tokens[index]
next_token = tokens[index + 1].lower() if index + 1 < len(tokens) else ""
if next_token in type_tokens and not (
current.lower() in inferrable_names and next_token in {"date", "datetime", "timestamp"}
):
parts.append(f"{current} {tokens[index + 1]}")
index += 2
else:
parts.append(current)
index += 1
continue
if re.search(type_pattern, part, flags=re.IGNORECASE):
parts.append(part)
continue
if len(tokens) > 1 and all(re.match(r"^[A-Za-z_][\wÀ-ÿ]*$", token) for token in tokens):
parts.extend(tokens)
else:
parts.append(part)
return parts
def format_create_table(table_name, columns):
    """Render a ``CREATE TABLE`` statement.

    ``columns`` is an iterable of ``(name, type)`` pairs; when a name
    repeats, only its first occurrence is kept. Returns ``""`` when the
    table name or the column list is empty.
    """
    if not (table_name and columns):
        return ""
    first_seen = {}
    for column_name, column_type in columns:
        # setdefault keeps the first type recorded for a repeated name.
        first_seen.setdefault(column_name, column_type)
    if not first_seen:
        return ""
    body = ",\n".join(f" {name} {kind}" for name, kind in first_seen.items())
    return f"CREATE TABLE {table_name} (\n" + body + "\n);"
def create_table_from_message(message):
    """Build a ``CREATE TABLE`` statement from a natural-language request.

    Recognises ``table <name> with <columns>`` (optionally "called"/"named")
    and ``<create verb> ... table <name> with <columns>``, plus the
    Portuguese wording. The first pattern that matches decides the result;
    ``""`` is returned when nothing matches or no column could be parsed.
    """
    text = (message or "").strip()
    patterns = (
        r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com|tendo|contendo)\s+(.+)$",
        rf"\b(?:{_CREATE_VERBS})\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com|tendo|contendo)\s+(.+)$",
    )
    for pattern in patterns:
        found = re.search(pattern, text, flags=re.IGNORECASE)
        if found is None:
            continue
        table_name = normalize_identifier(found.group(1))
        definitions = map(parse_column_definition, split_column_list(found.group(2)))
        columns = [definition for definition in definitions if definition]
        return format_create_table(table_name, columns)
    return ""
def parse_create_table_schema(schema):
    """Parse ``[CREATE TABLE] name (col ..., ...)`` into its parts.

    Returns ``(table_name, [(column_name, column_type), ...])``, or
    ``("", [])`` when the text does not have that shape.
    """
    text = (schema or "").strip()
    parsed = re.match(
        r"^\s*(?:CREATE\s+TABLE\s+)?([A-Za-z_][\w]*)\s*\((.*?)\)\s*;?\s*$",
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if parsed is None:
        return "", []
    raw_name, raw_columns = parsed.group(1), parsed.group(2)
    definitions = map(parse_column_definition, split_column_list(raw_columns))
    columns = [definition for definition in definitions if definition]
    return normalize_identifier(raw_name), columns
def _schema_column_names(schema):
    """Return the column names declared in ``schema``, in order."""
    columns = parse_create_table_schema(schema)[1]
    return [column[0] for column in columns]
def _canonical_column(name):
    """Map a column spelling to its ``COLUMN_ALIASES`` key.

    Returns the normalized name itself when no entry lists it.
    """
    key = normalize_text(name)
    return next(
        (
            canonical
            for canonical, aliases in COLUMN_ALIASES.items()
            if key == canonical or key in aliases
        ),
        key,
    )
# Sentinel that lets callers pass ``None`` (or "") as an explicit default.
_NO_COLUMN_DEFAULT = object()


def _column_from_message(normalized_message, schema_columns, canonical, default=_NO_COLUMN_DEFAULT):
    """Resolve a canonical column mentioned in the message to a schema column.

    If no alias of ``canonical`` occurs as a whole word, returns ``default``
    (``None`` when omitted). If one occurs, returns the first schema column
    whose canonical form matches; with no such column, returns ``default``
    (or ``canonical`` itself when omitted).
    """
    has_explicit_default = default is not _NO_COLUMN_DEFAULT
    mentioned = False
    for alias in COLUMN_ALIASES.get(canonical, {canonical}):
        if re.search(rf"(?<!\w){re.escape(alias)}(?!\w)", normalized_message):
            mentioned = True
            break
    if not mentioned:
        return default if has_explicit_default else None
    for column in schema_columns:
        if _canonical_column(column) == canonical:
            return column
    return default if has_explicit_default else canonical
# Canonical column names treated as numeric measures.
NUMERIC_COLUMN_CANONICALS = {"price", "salary", "grade", "amount", "total", "quantity", "stock", "weight", "year"}


def _first_numeric_focus(normalized_message, schema_columns, default="value"):
    """Pick the numeric column a question is about.

    Prefers a numeric column mentioned in the message (checked in a fixed
    priority order), then the first numeric column of the schema, then
    ``default``.
    """
    priority = ("price", "salary", "grade", "amount", "total", "quantity", "stock", "weight", "year")
    for canonical in priority:
        mentioned = _column_from_message(normalized_message, schema_columns, canonical, None)
        if mentioned:
            return mentioned
    return next(
        (
            column
            for column in schema_columns
            if _canonical_column(column) in NUMERIC_COLUMN_CANONICALS
        ),
        default,
    )
def _schema_column_for_canonical(schema_columns, canonical):
    """Return the first schema column whose canonical form is ``canonical``.

    Returns ``""`` when the schema has no such column.
    """
    return next(
        (column for column in schema_columns if _canonical_column(column) == canonical),
        "",
    )
def _schema_has_column(schema_columns, column_name):
return any(column == column_name for column in schema_columns)
def _mentioned_schema_column(normalized_message, schema_columns):
    """Return the first schema column that the message mentions.

    A column counts as mentioned when its own name, its canonical name, or
    any alias of that canonical name occurs as a whole word. Returns ``""``
    when no column is mentioned.
    """
    for column in schema_columns:
        canonical = _canonical_column(column)
        spellings = {normalize_text(column), canonical}
        spellings.update(COLUMN_ALIASES.get(canonical, set()))
        for spelling in spellings:
            if re.search(rf"(?<!\w){re.escape(spelling)}(?!\w)", normalized_message):
                return column
    return ""
def _dimension_after_group_marker(normalized_message, schema_columns):
    """Return the grouping column named after ``by``/``por``/``par``/``per``.

    A known alias is resolved to the matching schema column, falling back to
    its canonical name; any other word is returned as a normalized
    identifier. Returns ``""`` when the message has no grouping marker.
    """
    marker = re.search(r"\b(?:by|por|par|per)\s+([A-Za-z_][\w]*)", normalized_message)
    if marker is None:
        return ""
    token = normalize_text(marker.group(1))
    canonical = next(
        (
            name
            for name, aliases in COLUMN_ALIASES.items()
            if token == name or token in aliases
        ),
        None,
    )
    if canonical is None:
        return normalize_identifier(token)
    # The token itself serves as the "message" so the alias test passes.
    return _column_from_message(token, schema_columns, canonical, canonical)
def _table_subject(schema):
    """Return a singular noun for the rows of the schema's table.

    Known tables map to their canonical English name with a trailing ``s``
    removed. Other names longer than three characters lose a trailing
    ``s``. With no table name the subject is ``"row"``.
    """
    table_name, _columns = parse_create_table_schema(schema)
    key = normalize_text(table_name)
    canonical = next(
        (
            name
            for name, aliases in TABLE_ALIASES.items()
            if key == name or key in aliases
        ),
        None,
    )
    if canonical is not None:
        return canonical[:-1] if canonical.endswith("s") else canonical
    if len(key) > 3 and key.endswith("s"):
        return key[:-1]
    return key or "row"
def _has_marker(normalized_message, markers):
return any(marker in normalized_message for marker in markers)
def _clean_english_question(text):
text = re.sub(r"\s+", " ", text or "").strip(" .;:")
if not text:
return LANGUAGE_FALLBACK_QUESTION
return text if text.endswith("?") else f"{text}?"
def _lexical_translate_sql_question(normalized_message):
    """Translate a normalized question to English word by word.

    Phrase-level replacements run first. Each remaining word, number or
    comparison symbol is then looked up in the token translations, the
    column aliases and the table aliases, in that order; tokens with no
    translation are kept unchanged.
    """

    def canonical_for(word, alias_table):
        # First key whose name or alias set contains the word, else None.
        for canonical, aliases in alias_table.items():
            if word == canonical or word in aliases:
                return canonical
        return None

    text = normalized_message
    for pattern, replacement in PHRASE_TRANSLATIONS:
        text = re.sub(pattern, replacement, text)

    output = []
    for raw_word in re.findall(r"[A-Za-z_][\w]*|\d+(?:\.\d+)?|[<>=]+", text):
        key = normalize_text(raw_word)
        translated = (
            TOKEN_TRANSLATIONS.get(key)
            or canonical_for(key, COLUMN_ALIASES)
            or canonical_for(key, TABLE_ALIASES)
        )
        if translated:
            # A translation may be several words, e.g. "how many".
            output.extend(translated.split())
        else:
            output.append(raw_word)
    return " ".join(output)
def normalize_sql_question_to_english(message, schema=""):
    """Rewrite a question as an English prompt.

    Known intents (most expensive, cheapest, average, count) are turned into
    fixed English templates filled from the schema. Otherwise the message is
    translated word by word: the translation is used when the message has a
    foreign marker word, the raw message is returned when it has none, and
    the fallback question is returned when no letters remain.
    """
    raw_message = (message or "").strip()
    normalized_message = normalize_text(raw_message)
    if not normalized_message:
        return LANGUAGE_FALLBACK_QUESTION

    schema_columns = _schema_column_names(schema)
    subject = _table_subject(schema)

    priciest_markers = ("mais caro", "mas caro", "plus cher", "maior preco", "highest price", "most expensive")
    cheapest_markers = ("mais barato", "mas barato", "menor preco", "lowest price", "cheapest")
    average_markers = ("media", "moyenne", "promedio", "average")
    count_markers = ("quantos", "cuantos", "combien", "count", "how many")

    if _has_marker(normalized_message, priciest_markers):
        return _clean_english_question(f"what is the most expensive {subject}")
    if _has_marker(normalized_message, cheapest_markers):
        return _clean_english_question(f"what is the cheapest {subject}")
    if _has_marker(normalized_message, average_markers):
        metric = _first_numeric_focus(normalized_message, schema_columns)
        dimension = _dimension_after_group_marker(normalized_message, schema_columns)
        if not dimension:
            return _clean_english_question(f"what is the average {metric}")
        return _clean_english_question(f"what is the average {metric} by {dimension}")
    if _has_marker(normalized_message, count_markers):
        dimension = _dimension_after_group_marker(normalized_message, schema_columns)
        if not dimension:
            return _clean_english_question(f"how many {subject}s are there")
        return _clean_english_question(f"count {subject}s by {dimension}")

    translated = _lexical_translate_sql_question(normalized_message)
    if not re.search(r"[A-Za-z]", translated):
        return LANGUAGE_FALLBACK_QUESTION
    for marker in FOREIGN_SQL_MARKERS:
        if re.search(rf"(?<!\w){re.escape(marker)}(?!\w)", normalized_message):
            return _clean_english_question(translated)
    # No foreign word found: keep the user's wording untouched.
    return raw_message
def deterministic_sql_query(message, schema=""):
    """Answer simple, recognisable questions with rule-built SQL.

    Handles, in order: most expensive / cheapest row, average (optionally
    grouped), row count (optionally grouped), ``<column> greater|less than
    <number>``, highest / lowest of a mentioned numeric column, and "list
    all". Returns ``""`` when the schema cannot be parsed or no rule applies.
    """
    table_name, columns = parse_create_table_schema(schema)
    schema_columns = [column_name for column_name, _column_type in columns]
    if not (table_name and schema_columns):
        return ""

    native = normalize_text(message)
    english = normalize_text(normalize_sql_question_to_english(message, schema))

    def extreme_row(column, direction):
        return f"SELECT * FROM {table_name} ORDER BY {column} {direction} LIMIT 1;"

    def price_like_column():
        return (
            _schema_column_for_canonical(schema_columns, "price")
            or _schema_column_for_canonical(schema_columns, "amount")
        )

    wants_priciest = _has_marker(native, ("mais caro", "mas caro", "plus cher", "maior preco")) or _has_marker(
        english, ("most expensive", "highest price")
    )
    if wants_priciest:
        metric = price_like_column()
        if metric and _schema_has_column(schema_columns, metric):
            return extreme_row(metric, "DESC")

    wants_cheapest = _has_marker(native, ("mais barato", "mas barato", "menor preco")) or _has_marker(
        english, ("cheapest", "lowest price")
    )
    if wants_cheapest:
        metric = price_like_column()
        if metric and _schema_has_column(schema_columns, metric):
            return extreme_row(metric, "ASC")

    if _has_marker(native, ("media", "moyenne", "promedio")) or "average" in english:
        metric = _first_numeric_focus(native, schema_columns, "")
        dimension = _dimension_after_group_marker(native, schema_columns)
        if metric and _schema_has_column(schema_columns, metric):
            if dimension and _schema_has_column(schema_columns, dimension):
                return (
                    f"SELECT {dimension}, AVG({metric}) AS average_{metric} "
                    f"FROM {table_name} GROUP BY {dimension};"
                )
            return f"SELECT AVG({metric}) AS average_{metric} FROM {table_name};"

    if _has_marker(native, ("quantos", "cuantos", "combien")) or "how many" in english:
        dimension = _dimension_after_group_marker(native, schema_columns)
        if dimension and _schema_has_column(schema_columns, dimension):
            return f"SELECT {dimension}, COUNT(*) AS row_count FROM {table_name} GROUP BY {dimension};"
        return f"SELECT COUNT(*) AS row_count FROM {table_name};"

    comparison = re.search(
        r"\b([A-Za-z_][\w]*)\s+(greater than|less than)\s+(\d+(?:\.\d+)?)\b",
        english,
    )
    if comparison:
        compared_word, relation, threshold = comparison.groups()
        actual_column = _schema_column_for_canonical(schema_columns, _canonical_column(compared_word))
        if actual_column:
            operator = ">" if relation == "greater than" else "<"
            return f"SELECT * FROM {table_name} WHERE {actual_column} {operator} {threshold};"

    ranking_column = _mentioned_schema_column(native, schema_columns)
    if ranking_column and _canonical_column(ranking_column) in NUMERIC_COLUMN_CANONICALS:
        # "maior que" / "menor que" are comparisons, not rankings.
        wants_highest = (
            _has_marker(native, ("maior", "maximo", "mais alto"))
            or _has_marker(english, ("highest", "maximum", "max"))
        ) and "maior que" not in native
        if wants_highest:
            return extreme_row(ranking_column, "DESC")
        wants_lowest = (
            _has_marker(native, ("menor", "minimo", "mais baixo"))
            or _has_marker(english, ("lowest", "minimum", "min"))
        ) and "menor que" not in native
        if wants_lowest:
            return extreme_row(ranking_column, "ASC")

    if english in {"list all", "show all", "select all rows", "list all rows", "show all rows"}:
        return f"SELECT * FROM {table_name};"
    return ""
def create_table_from_schema(schema):
    """Re-render ``schema`` as a normalized ``CREATE TABLE`` statement.

    Returns ``""`` when the schema cannot be parsed or has no columns.
    """
    parsed = parse_create_table_schema(schema)
    return format_create_table(parsed[0], parsed[1])
def extract_create_table_statement(text):
    """Return the first ``CREATE TABLE name (...)`` statement in ``text``.

    Returns ``""`` when the text contains no such statement.
    """
    candidate = extract_sql_candidate(text)
    statement = re.search(
        r"\bCREATE\s+TABLE\s+[A-Za-z_][\w]*\s*\(.*?\)\s*;?",
        candidate,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if statement is None:
        return ""
    return clean_generation(statement.group(0))
def last_create_table_from_history(chat_history):
    """Return the most recent ``CREATE TABLE`` an assistant turn produced.

    Walks the history from newest to oldest, looking only at dict items
    whose ``role`` is ``"assistant"``. Returns ``""`` when none contains a
    statement.
    """
    history = list(chat_history or [])
    for entry in history[::-1]:
        if isinstance(entry, dict) and entry.get("role") == "assistant":
            statement = extract_create_table_statement(entry.get("content", ""))
            if statement:
                return statement
    return ""
def extract_added_columns(message):
    """Extract the ``(name, type)`` columns a message asks to add.

    Text after a colon is tried first, then text after an "add"-style verb
    (with optional article, "new", and column noun). The first pattern that
    yields at least one column is used; otherwise ``[]`` is returned.
    """
    text = (message or "").strip()
    patterns = (
        r":\s*(.+)$",
        r"\b(?:add|include|with|adicionar|adicione|adicionando|inclua|incluir|acrescente|ter|coloque|colocar)\b\s+(?:um\s+|uma\s+|a\s+|an\s+)?(?:novo\s+|nova\s+|new\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
    )
    for pattern in patterns:
        found = re.search(pattern, text, flags=re.IGNORECASE)
        if found is None:
            continue
        definitions = map(parse_column_definition, split_column_list(found.group(1)))
        columns = [definition for definition in definitions if definition]
        if columns:
            return columns
    return []
def extract_removed_columns(message):
    """Extract the normalized names of columns a message asks to remove.

    Everything after a removal verb (with optional article and column noun)
    is split into column names. Returns ``[]`` when no removal verb is
    present or no name survives normalization.
    """
    text = (message or "").strip()
    patterns = (
        r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$",
    )
    for pattern in patterns:
        found = re.search(pattern, text, flags=re.IGNORECASE)
        if found is None:
            continue
        identifiers = map(normalize_identifier, split_column_list(found.group(1)))
        columns = [identifier for identifier in identifiers if identifier]
        if columns:
            return columns
    return []
def extract_renamed_columns(message):
    """Extract ``(old_name, new_name)`` rename pairs from a message.

    Recognises ``<rename verb> X to|para|as|como|por Y`` and ``troca X por
    Y``. Both names are normalized to identifiers. A pair is dropped when
    either name normalizes to nothing, when the old name is a pronoun, or
    when the new name is a filler word such as ``ter`` or ``have``.
    """
    text = message or ""
    pattern = (
        r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|alterar|altera|mude|mudar|muda|edita|editar)\s+"
        r"(\w+)\s+(?:to|para|as|como|por)\s+(\w+)"
    )
    candidates = [
        *re.findall(pattern, text, flags=re.IGNORECASE),
        *re.findall(r"\btroca\b\s+(\w+)\s+\bpor\b\s+(\w+)", text, flags=re.IGNORECASE),
    ]
    invalid_old_names = {"ela", "ele", "isso", "isto", "essa", "esse", "this", "it"}
    # Names are compared after normalization, which strips accents, so the
    # accent-free spelling is sufficient here.
    invalid_new_names = {"ter", "have", "having", "tambem"}

    renames = []
    for raw_old, raw_new in candidates:
        # Normalize each side once and reuse the result for every check.
        old_name = normalize_identifier(raw_old)
        new_name = normalize_identifier(raw_new)
        if not old_name or not new_name:
            continue
        if old_name in invalid_old_names or new_name in invalid_new_names:
            continue
        renames.append((old_name, new_name))
    return renames
def parse_compound_edit(message):
    """Split a multi-part edit request into added, removed and renamed columns.

    The message is cut at ``and`` / ``e`` whenever the next word is an edit
    verb. Each segment is classified as a rename, a removal, or (by default)
    an addition, and passed to the matching extractor.

    Returns:
        ``(added, removed, renamed)`` where ``added`` holds ``(name, type)``
        pairs, ``removed`` holds column names and ``renamed`` holds
        ``(old, new)`` pairs.
    """
    segment_pattern = (
        r"\s+(?:and|e)\s+"
        r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
        r"adicione|adicionar|adicionando|inclua|incluir|acrescente|"
        r"coloca|coloque|bota|insira|insere|"
        r"remova|remover|deletar|exclua|excluir|exclui|"
        r"tira|tirar|tire|retira|retire|apaga|apague|"
        r"renomeie|renomear|renomeia|altere|alterar|altera|"
        r"mude|mudar|muda|edita|editar|troca|trocar)\b)"
    )
    # Same verbs that extract_removed_columns recognises. With a shorter list
    # a segment such as "apaga o campo: nome" fell through to the "add"
    # branch and its column was treated as an addition.
    removal_verbs = (
        r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|"
        r"tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b"
    )
    segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE)
    added, removed, renamed = [], [], []
    for segment in segments:
        segment = segment.strip()
        if not segment:
            continue
        if is_rename_intent(segment):
            renamed.extend(extract_renamed_columns(segment))
        elif re.search(removal_verbs, segment, flags=re.IGNORECASE):
            removed.extend(extract_removed_columns(segment))
        else:
            added.extend(extract_added_columns(segment))
    return added, removed, renamed
def edit_create_table_from_message(message, chat_history, active_schema):
    """Apply an add/remove/rename request to the current table definition.

    The base definition is the last ``CREATE TABLE`` found in the assistant
    history, or else the active schema. Returns the updated statement, or
    ``""`` when the message is not an edit request, there is no base table,
    nothing could be extracted, or the column list would not change.
    """
    if not (is_table_edit_intent(message) or is_rename_intent(message)):
        return ""
    base_sql = last_create_table_from_history(chat_history) or create_table_from_schema(active_schema)
    table_name, existing_columns = parse_create_table_schema(base_sql)
    if not table_name:
        return ""

    added_columns, segment_removals, renamed_columns = parse_compound_edit(message)
    # Removals come from the whole message as well as from each segment.
    removed = set(extract_removed_columns(message)).union(segment_removals)
    if not (added_columns or removed or renamed_columns):
        return ""

    rename_map = dict(renamed_columns)
    updated_columns = []
    for column_name, column_type in existing_columns:
        if column_name in removed:
            continue
        updated_columns.append((rename_map.get(column_name, column_name), column_type))
    updated_columns.extend(added_columns)

    if updated_columns == existing_columns:
        return ""
    return format_create_table(table_name, updated_columns)
def create_table_from_suggestion(suggestion):
    """Render a ``CREATE TABLE`` statement from a table suggestion.

    ``suggestion`` may be a dict (``table_name`` plus ``columns`` as a list
    of ``{"name", "type"}`` dicts) or an object exposing ``table_name`` and
    ``columns`` as ``(name, type)`` pairs. Missing types default to
    ``TEXT``; columns whose name normalizes to nothing are skipped. Returns
    ``""`` for a falsy suggestion.
    """
    if not suggestion:
        return ""

    if isinstance(suggestion, dict):
        table_name = suggestion.get("table_name")
        raw_columns = [
            (entry.get("name"), entry.get("type", "TEXT"))
            for entry in suggestion.get("columns", [])
            if isinstance(entry, dict)
        ]
    else:
        table_name = getattr(suggestion, "table_name", "")
        raw_columns = getattr(suggestion, "columns", ())

    normalized_columns = []
    for raw_name, raw_type in raw_columns:
        identifier = normalize_identifier(raw_name)
        if not identifier:
            continue
        normalized_columns.append((identifier, (raw_type or "TEXT").upper()))
    return format_create_table(normalize_identifier(table_name), normalized_columns)