| import html |
| import re |
| import unicodedata |
|
|
| import sqlparse |
|
|
|
|
# Keywords accepted as the first word/token of a SQL statement. is_sql_like and
# validate_sql compare the uppercased leading word against this set.
SQL_STARTERS = {"SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP"}


# Question returned by the English normaliser when it cannot build a usable question.
LANGUAGE_FALLBACK_QUESTION = "write a SELECT query for this table."
|
|
# Accent-stripped, lowercase words (mostly Portuguese, Spanish and French) whose presence
# as a whole word tells normalize_sql_question_to_english that the lexically translated
# question should be used instead of the raw message.
FOREIGN_SQL_MARKERS = {
    "alumnos", "alunos", "barato", "caro", "clientes", "combien", "conte",
    "contar", "cuantos", "departamento", "empleados", "estoque",
    "funcionarios", "liste", "listar", "maior", "mas", "media", "menor",
    "mostre", "mostrar", "moyenne", "ordenado", "ordene", "par", "por",
    "preco", "precio", "produit", "produits", "producto", "productos",
    "produto", "produtos", "qual", "quantos", "salaire", "salario", "soma",
    "vendas",
}
|
|
# Canonical English table name -> accepted spellings (singular/plural, PT/ES/FR).
# Used to translate table words and to derive the singular "subject" of a table.
TABLE_ALIASES = {
    "employees": {"employee", "employees", "funcionario", "funcionarios", "empleado", "empleados"},
    "orders": {"order", "orders", "pedido", "pedidos"},
    "students": {"student", "students", "aluno", "alunos", "alumno", "alumnos", "etudiant", "etudiants"},
    "products": {"product", "products", "produto", "produtos", "producto", "productos", "produit", "produits"},
    "sales": {"sale", "sales", "venda", "vendas", "venta", "ventas"},
    "customers": {"customer", "customers", "cliente", "clientes"},
}
|
|
# Canonical English column name -> accepted spellings. Every alias set also contains the
# canonical name itself. Aliases are lowercase and accent-free so they can be compared
# with normalize_text output.
COLUMN_ALIASES = {
    "amount": {"amount", "valor", "importe", "montant"},
    "category": {"category", "categoria", "categorie"},
    "course": {"course", "curso"},
    "customer_id": {"customer_id", "cliente_id", "id_cliente"},
    "date": {"date", "data", "fecha"},
    "department": {"department", "departamento", "departement"},
    "grade": {"grade", "nota", "calificacion"},
    "name": {"name", "nome", "nombre", "nom"},
    "price": {"price", "preco", "precio", "prix"},
    "product": {"product", "produto", "producto", "produit"},
    "product_id": {"product_id", "produto_id", "producto_id", "id_produto"},
    "quantity": {"quantity", "quantidade", "cantidad", "quantite"},
    "salary": {"salary", "salario", "salaire", "sueldo"},
    "stock": {"stock", "estoque", "inventario"},
    "total": {"total", "soma", "sum"},
    "weight": {"weight", "peso", "poids"},
    "year": {"year", "ano", "anio", "annee"},
}
|
|
# (regex, English replacement) pairs applied in this order to the normalised message by
# _lexical_translate_sql_question, before word-by-word translation. Multi-word phrases
# are handled here so they are not translated word by word.
PHRASE_TRANSLATIONS = (
    (r"\bmaior que\b", "greater than"),
    (r"\bmenor que\b", "less than"),
    (r"\bmaior ou igual a\b", "greater than or equal to"),
    (r"\bmenor ou igual a\b", "less than or equal to"),
    (r"\bmais caro\b", "most expensive"),
    (r"\bmas caro\b", "most expensive"),
    (r"\bplus cher\b", "most expensive"),
    (r"\bpreco mais alto\b", "highest price"),
    (r"\bprecio mas alto\b", "highest price"),
    (r"\bmaior preco\b", "highest price"),
    (r"\bmais barato\b", "cheapest"),
    (r"\bmas barato\b", "cheapest"),
    (r"\bplus bas prix\b", "cheapest"),
    (r"\bmenor preco\b", "lowest price"),
    (r"\bmaior para menor\b", "descending"),
    (r"\bmenor para maior\b", "ascending"),
)
|
|
# Single-word translations (normalised foreign word -> English word or phrase) consulted
# first by _lexical_translate_sql_question; words not found here fall back to
# COLUMN_ALIASES and then TABLE_ALIASES.
TOKEN_TRANSLATIONS = {
    "a": "the",
    "agrupe": "group",
    "alunos": "students",
    "barato": "cheap",
    "caro": "expensive",
    "com": "with",
    "conte": "count",
    "contar": "count",
    "cuantos": "how many",
    "da": "of",
    "das": "of",
    "de": "of",
    "do": "of",
    "dos": "of",
    "e": "and",
    "em": "in",
    "filtre": "filter",
    "filtrar": "filter",
    "funcionarios": "employees",
    "liste": "list",
    "listar": "list",
    "maior": "highest",
    "mas": "more",
    "media": "average",
    "menor": "lowest",
    "mostre": "show",
    "mostrar": "show",
    "moyenne": "average",
    "o": "the",
    "os": "the",
    "ordene": "order",
    "ordenado": "ordered",
    "par": "by",
    "para": "for",
    "por": "by",
    "produto": "product",
    "produtos": "products",
    "peso": "weight",
    "qual": "what",
    "quantos": "how many",
    "soma": "sum",
    "some": "sum",
    "todos": "all",
}
|
|
|
|
def content_to_text(value):
    """Flatten a chat-message payload into plain text.

    * ``None`` becomes ``""``; strings are returned unchanged.
    * Dicts yield the first present key among ``"text"``, ``"content"`` and
      ``"value"`` (converted recursively); otherwise all values are converted
      and joined with single spaces.
    * Lists and tuples are converted item by item and joined with newlines.
    * Anything else goes through ``str()``.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        preferred_key = next((key for key in ("text", "content", "value") if key in value), None)
        if preferred_key is not None:
            return content_to_text(value[preferred_key])
        return " ".join(map(content_to_text, value.values()))
    if isinstance(value, (list, tuple)):
        return "\n".join(map(content_to_text, value))
    return str(value)
|
|
|
|
def normalize_text(value):
    """Return a lowercase, accent-free, whitespace-collapsed version of *value*.

    *value* is first flattened with ``content_to_text``. After NFKD decomposition,
    combining marks are dropped, so "Preço" becomes "preco". Runs of whitespace
    collapse to single spaces and the ends are trimmed.
    """
    lowered = content_to_text(value).lower()
    decomposed = unicodedata.normalize("NFKD", lowered)
    base_characters = "".join(filter(lambda char: not unicodedata.combining(char), decomposed))
    return re.sub(r"\s+", " ", base_characters).strip()
|
|
|
|
def clean_generation(text):
    """Strip common wrapping from raw model output.

    Steps, in order:
    1. Remove a leading ```` ``` ```` / ```` ```sql ```` fence line and a trailing
       ```` ``` ```` line (only when the text starts with a fence).
    2. Cut everything from the first chat-template stop marker onwards.
    3. Drop a leading case-insensitive ``"SQL:"`` label.
    """
    result = content_to_text(text).strip()
    if result.startswith("```"):
        body = result.splitlines()
        if body and body[0].strip().lower() in {"```", "```sql"}:
            del body[0]
        if body and body[-1].strip() == "```":
            del body[-1]
        result = "\n".join(body).strip()
    for stop_marker in ("<|end|>", "<|user|>", "<|assistant|>", "</s>"):
        head, found, _tail = result.partition(stop_marker)
        if found:
            result = head.strip()
    if result.upper().startswith("SQL:"):
        result = result[4:].strip()
    return result
|
|
|
|
def extract_sql_candidate(text):
    """Return the cleaned text starting at the first SQL statement keyword.

    When no keyword is found, the cleaned text is returned unchanged.
    """
    cleaned = clean_generation(text)
    keyword = re.search(
        r"\b(SELECT|WITH|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b", cleaned, flags=re.IGNORECASE
    )
    return cleaned[keyword.start():].strip() if keyword else cleaned
|
|
|
|
def is_sql_like(text):
    """Return True when *text*'s first alphabetic word is one of SQL_STARTERS."""
    stripped = (text or "").strip()
    leading_word = re.match(r"^\s*([A-Za-z]+)", stripped) if stripped else None
    return bool(leading_word) and leading_word.group(1).upper() in SQL_STARTERS
|
|
|
|
def is_sql_intent(message, schema=""):
    """Heuristically decide whether *message* asks for a SQL query.

    Greetings and small talk (PT/EN) return False. Otherwise the message counts
    as SQL intent when it contains a known query word or phrase as a whole word.
    Failing that, it counts when a schema is given and the message itself starts
    like a SQL statement.
    """
    text = normalize_text(message)
    if not text:
        return False

    greetings = {
        "oi", "ola", "olá", "hi", "hello", "hey", "obrigado", "obrigada", "thanks",
        "thank you", "como voce esta", "como você esta", "qual seu nome", "me conte uma piada",
        "vamos conversar", "como voce funciona", "como funciona", "o que voce faz", "o que faz",
    }
    small_talk_fragments = ("como voce esta", "qual seu nome", "conte uma piada")
    normalized_greetings = set(map(normalize_text, greetings))
    if text in normalized_greetings or any(fragment in text for fragment in small_talk_fragments):
        return False

    query_vocabulary = {
        "all", "average", "count", "columns", "database", "find", "get", "group by",
        "highest", "join", "least", "list", "lowest", "max", "maximum", "min", "minimum",
        "most", "most expensive", "order by", "query", "rows", "select", "show", "sum",
        "where",
        "consulta", "consultar", "contar", "colunas", "linhas", "liste", "listar",
        "maior", "mais caro", "menor", "media", "mostre", "mostrar", "ordene",
        "selecione", "some", "soma", "quantos", "filtre", "filtrar",
    }
    for term in query_vocabulary:
        if re.search(rf"(?<!\w){re.escape(normalize_text(term))}(?!\w)", text):
            return True
    return bool(schema and is_sql_like(text))
|
|
|
|
# Uppercased function names that sql_schema_validation_issue never reports as unknown
# columns (sqlparse tokenises function names as Name tokens, like column names).
SQL_SCHEMA_FUNCTION_NAMES = {
    "AVG",
    "COALESCE",
    "COUNT",
    "LOWER",
    "MAX",
    "MIN",
    "ROUND",
    "SUM",
    "UPPER",
}


# Keywords that can directly follow "FROM/JOIN <table>" and therefore must not be taken
# as a table alias by sql_schema_validation_issue.
SQL_ALIAS_STOPWORDS = {
    "FULL",
    "GROUP",
    "HAVING",
    "INNER",
    "JOIN",
    "LEFT",
    "LIMIT",
    "ON",
    "ORDER",
    "RIGHT",
    "WHERE",
}
|
|
|
|
| def _without_sql_literals(sql_text): |
| return re.sub( |
| r"'(?:''|[^'])*'|\"(?:\"\"|[^\"])*\"|\b\d+(?:\.\d+)?\b", |
| " ", |
| sql_text or "", |
| ) |
|
|
|
|
| def _without_sql_value_literals(sql_text): |
| without_quoted_values = re.sub( |
| r"(=|<>|!=|<=|>=|<|>)\s*\"(?:\"\"|[^\"])*\"", |
| r"\1 ", |
| sql_text or "", |
| ) |
| return re.sub( |
| r"'(?:''|[^'])*'|\b\d+(?:\.\d+)?\b", |
| " ", |
| without_quoted_values, |
| ) |
|
|
|
|
def _identifier_names(sql_text):
    """List every identifier-like token in *sql_text* with its neighbouring tokens.

    Returns a list of ``(name, previous_value, next_value)`` tuples, one per
    sqlparse token whose ttype is under ``Name`` or is exactly
    ``Literal.String.Symbol`` (double-quoted identifiers, returned without the
    quotes). ``previous_value``/``next_value`` are the values of the nearest
    non-whitespace tokens before/after it in the same statement, or ``""`` at the
    edges. Returns ``[]`` when ``sqlparse.parse`` raises.
    """
    try:
        parsed = [statement for statement in sqlparse.parse(sql_text) if str(statement).strip()]
    except Exception:
        return []
    names = []
    for statement in parsed:
        # Whitespace is skipped once up front, so neighbours are simply adjacent entries.
        significant = [token for token in statement.flatten() if not token.is_whitespace]
        for position, token in enumerate(significant):
            is_name = token.ttype in sqlparse.tokens.Name
            is_quoted_symbol = token.ttype is sqlparse.tokens.Literal.String.Symbol
            if not (is_name or is_quoted_symbol):
                continue
            previous_value = significant[position - 1].value if position > 0 else ""
            next_value = significant[position + 1].value if position + 1 < len(significant) else ""
            names.append((token.value.strip('"'), previous_value, next_value))
    return names
|
|
|
|
def sql_schema_validation_issue(sql_text, schema):
    """Return a description of the first schema problem in *sql_text*, or ``""``.

    The active schema is parsed with ``parse_create_table_schema``. If it yields no
    table name or no columns, validation is skipped and ``""`` is returned.
    Otherwise the checks run in order:

    1. the SQL has at least one ``FROM``/``JOIN`` target;
    2. every ``FROM``/``JOIN`` target is the active table or a CTE name;
    3. every identifier token is one of: a schema column, the table/CTE name, an
       alias of them, an ``AS`` alias, an alias written directly after ``)``, a
       whitelisted function name, or a qualifier followed by ``.``.

    Args:
        sql_text: SQL produced by the model.
        schema: ``CREATE TABLE`` text (or shorthand) describing the active table.

    Returns:
        A human-readable issue string, or ``""`` when nothing was detected.
    """
    table_name, columns = parse_create_table_schema(schema)
    if not table_name or not columns:
        return ""

    expected_table = table_name.lower()
    allowed_columns = {name.lower() for name, _column_type in columns}
    # Quoted text and numbers are blanked so words inside literals are not read as SQL.
    scrubbed_sql = _without_sql_literals(sql_text)

    cte_aliases = {
        match.group(1).lower()
        for match in re.finditer(
            r"(?:WITH|,)\s+([A-Za-z_][\w]*)\s+AS\s*\(",
            scrubbed_sql,
            flags=re.IGNORECASE,
        )
    }
    table_refs = [
        match.group(1).lower()
        for match in re.finditer(
            r"\b(?:FROM|JOIN)\s+([A-Za-z_][\w]*)",
            scrubbed_sql,
            flags=re.IGNORECASE,
        )
    ]
    if not table_refs:
        return f"Model SQL does not reference active table: {table_name}"
    known_tables = {expected_table, *cte_aliases}
    for table_ref in table_refs:
        if table_ref not in known_tables:
            return f"Unknown table for active schema: {table_ref}"

    # Bare or AS-aliases after "FROM/JOIN <name>", where <name> is the active table
    # or any CTE. CTE aliases must be included here; otherwise "FROM ranked r" makes
    # "r" look like an unknown column.
    table_alternatives = "|".join(
        re.escape(name) for name in sorted(known_tables, key=len, reverse=True)
    )
    table_alias_pattern = (
        r"\b(?:FROM|JOIN)\s+"
        rf"(?:{table_alternatives})\b"
        r"\s+(?:AS\s+)?([A-Za-z_][\w]*)"
    )
    table_aliases = set()
    for match in re.finditer(table_alias_pattern, scrubbed_sql, flags=re.IGNORECASE):
        alias = match.group(1)
        if alias.upper() not in SQL_ALIAS_STOPWORDS:
            table_aliases.add(alias.lower())

    identifier_sql = _without_sql_value_literals(sql_text)
    output_aliases = {
        (match.group(1) or match.group(2)).lower()
        for match in re.finditer(
            r"\bAS\s+(?:\"([^\"]+)\"|([A-Za-z_][\w]*))",
            identifier_sql,
            flags=re.IGNORECASE,
        )
    }
    # A word directly after ")" is an alias of a function call or derived table
    # ("COUNT(*) total", "FROM (SELECT ...) sub"), never a column reference.
    # Keywords caught here are harmless because they are not Name tokens.
    parenthesis_aliases = {
        match.group(1).lower()
        for match in re.finditer(r"\)\s*([A-Za-z_][\w]*)", identifier_sql)
    }
    allowed_non_columns = {
        expected_table,
        *cte_aliases,
        *table_aliases,
        *output_aliases,
        *parenthesis_aliases,
    }
    for name, _previous_value, next_value in _identifier_names(identifier_sql):
        normalized_name = name.lower()
        if next_value == ".":
            # Qualifier such as "e" in "e.salary"; the part after the dot is checked separately.
            continue
        if normalized_name in allowed_non_columns:
            continue
        if name.upper() in SQL_SCHEMA_FUNCTION_NAMES:
            continue
        if normalized_name in allowed_columns:
            continue
        return f"Unknown column for active schema: {name}"
    return ""
|
|
|
|
def validate_sql(sql_text, schema=""):
    """Return an HTML badge (``<span>`` markup) describing the state of *sql_text*.

    Checks run in order and the first failure wins:
    1. empty input -> "No SQL yet" badge;
    2. sqlparse failure or no statements -> "Check syntax";
    3. first token not in SQL_STARTERS -> "Check syntax";
    4. dangling trailing keyword, bare ``NOT <name>`` predicate, or trailing
       comparison operator -> "Check syntax";
    5. ``sql_schema_validation_issue`` finds a problem -> "Check schema".
    Otherwise a "Valid SQL" badge is returned. Dynamic text is HTML-escaped.
    """
    def _warning(badge, detail):
        return (
            f'<span class="validator-badge validator-warn">{badge}</span>'
            f'<span class="validator-detail">{detail}</span>'
        )

    query = (sql_text or "").strip()
    if not query:
        return '<span class="validator-badge validator-empty">No SQL yet</span>'
    try:
        statements = [stmt for stmt in sqlparse.parse(query) if str(stmt).strip()]
    except Exception as exc:
        return _warning("Check syntax", f"sqlparse error: {html.escape(type(exc).__name__)}")
    if not statements:
        return _warning("Check syntax", "No parsed SQL statement.")

    first = statements[0].token_first(skip_cm=True)
    leading_keyword = "UNKNOWN" if first is None else first.value.strip().upper()
    if leading_keyword not in SQL_STARTERS:
        return _warning("Check syntax", f"First token: {html.escape(leading_keyword)}")

    # (pattern, flags, group to report, detail prefix, uppercase the reported token)
    trailing_checks = (
        (
            r"\b(AND|BY|FROM|GROUP|HAVING|JOIN|LIMIT|NOT|ON|OR|ORDER|SELECT|WHERE)\s*;?\s*$",
            re.IGNORECASE,
            1,
            "Incomplete trailing clause: ",
            True,
        ),
        (
            r"\b(AND|HAVING|OR|WHERE)\s+NOT\s+([A-Za-z_][\w.]*)\s*;?\s*$",
            re.IGNORECASE,
            2,
            "Incomplete negated predicate: NOT ",
            False,
        ),
        (
            r"(=|<>|!=|<=|>=|<|>)\s*;?\s*$",
            0,
            1,
            "Incomplete comparison operator: ",
            False,
        ),
    )
    for pattern, flags, group, label, uppercase in trailing_checks:
        found = re.search(pattern, query, flags=flags)
        if found:
            token = found.group(group)
            if uppercase:
                token = token.upper()
            return _warning("Check syntax", label + html.escape(token))

    schema_issue = sql_schema_validation_issue(query, schema)
    if schema_issue:
        return _warning("Check schema", html.escape(schema_issue))
    return '<span class="validator-badge validator-ok">Valid SQL</span>'
|
|
|
|
| |
# Regex alternation of "create"-like verbs (English and Portuguese conjugations), used by
# is_create_table_intent and create_table_from_message. "fa\u00e7a" is the regex escape
# for "faça"; the accent-free spelling "faca" is listed too for normalised text.
_CREATE_VERBS = (
    r"create|make|build|generate"
    r"|criar|crie|cria|criando"
    r"|gerar|gere|gera|gerando"
    r"|faz|faca|fa\u00e7a|fazendo"
    r"|monta|montar|monte"
    r"|construa|construir|constroi"
    r"|elabore|elaborar|elabora"
)
|
|
|
|
def is_create_table_intent(message):
    """Return True when *message* has both a create-verb and the word table/schema/tabela."""
    lowered = (message or "").strip().lower()
    has_create_verb = re.search(rf"\b({_CREATE_VERBS})\b", lowered) is not None
    return has_create_verb and re.search(r"\b(table|schema|tabela)\b", lowered) is not None
|
|
|
|
def is_rename_intent(message):
    """Return True when ``extract_renamed_columns`` finds anything in *message*.

    NOTE(review): ``extract_renamed_columns`` is defined outside this chunk;
    presumably it returns a (possibly empty) collection of renames.
    """
    renames = extract_renamed_columns(message)
    return bool(renames)
|
|
|
|
def is_table_edit_intent(message):
    """Heuristically detect a request to edit the table structure (PT/EN).

    Returns True for any of:
    * a direct "add" verb, unless the next word is an aggregation word
      ("add up", "some sum", ...);
    * a direct "remove" verb;
    * a rename (``is_rename_intent``);
    * "altere/mude ... ter" ("change it to have ...");
    * a generic edit verb together with a column/field word, a ":" or "por".
    """
    text = (message or "").strip().lower()
    edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|exclui|novo|nova|troca|trocar|coloque|colocar|coloca|insira|insere|bota|tira|retire|retira|apaga|apague)\b"
    direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar|coloca|acrescenta|insere|inserir|insira|bota|botar|bote)\b"
    direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b"
    target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b"
    sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by", "soma", "media", "contagem", "maximo", "minimo"}

    is_add_intent = False
    add_hit = re.search(direct_add_terms, text)
    if add_hit:
        following_words = text[add_hit.end():].split()
        next_word = following_words[0] if following_words else ""
        # "add up the prices" is an aggregation request, not a schema edit.
        is_add_intent = next_word not in sql_aggregation_terms

    if is_add_intent or re.search(direct_remove_terms, text) or is_rename_intent(text):
        return True
    if re.search(r"\b(?:altere|alterar|mude|mudar)\b.*\bter\b", text):
        return True
    if not re.search(edit_terms, text):
        return False
    return bool(re.search(target_terms, text) or ":" in text or re.search(r"\bpor\b", text))
|
|
|
|
def is_unsupported_data_mutation_intent(message):
    """Return True when *message* asks to change table rows (or drop the table).

    Create-table requests are never flagged. Detected cases:
    * explicit DML/DDL phrases (INSERT INTO, UPDATE ... SET, DELETE FROM, DROP/TRUNCATE TABLE);
    * the word "insert" without any column/field word;
    * an add/remove verb later followed by a row/record word;
    * a destructive verb immediately followed by all/every/rows/table (PT/EN).
    """
    text = normalize_text(message)
    if not text or is_create_table_intent(text):
        return False

    explicit_dml = (
        r"\b(?:insert\s+into|update\b.+\bset\b|delete\s+from|"
        r"drop\s+table|drop\s+(?:the\s+|a\s+|an\s+)?\w+\s+table|truncate\s+table)\b"
    )
    if re.search(explicit_dml, text):
        return True

    mentions_insert = re.search(r"\binsert\b", text) is not None
    mentions_column = re.search(r"\b(column|field|coluna|campo)\b", text) is not None
    if mentions_insert and not mentions_column:
        return True

    row_terms = r"(?:row|rows|record|records|linha|linhas|registro|registros)"
    row_mutation = (
        r"\b(?:add|create|insert|include|remove|delete|drop|"
        r"adicionar|adicione|criar|crie|incluir|inclua|remover|remova|deletar|excluir|exclua|apagar|apague)\b"
        rf".*\b{row_terms}\b"
    )
    destructive_all = (
        r"\b(?:delete|drop|remove|deletar|exclua|excluir|apaga|apagar|apague|remova|remover)\b"
        r"\s+(?:all|every|rows?|records?|table|todos|todas|linhas?|registros?|tabela)\b"
    )
    return bool(re.search(row_mutation, text) or re.search(destructive_all, text))
|
|
|
|
def _table_name_variants(table_name):
    """Return the normalised table name plus one naive singular/plural variant.

    "categories" -> {"categories", "category"}; "orders" -> {"orders", "order"};
    "stock" -> {"stock", "stocks"}. An empty name yields an empty set.
    """
    base = normalize_text(table_name)
    if not base:
        return set()
    if base.endswith("ies") and len(base) > 3:
        alternate = base[:-3] + "y"
    elif base.endswith("s") and len(base) > 1:
        alternate = base[:-1]
    else:
        alternate = base + "s"
    return {name for name in (base, alternate) if name}
|
|
|
|
def is_unsupported_data_mutation_for_schema(message, active_schema=""):
    """Like ``is_unsupported_data_mutation_intent``, but also flags dropping the active table by name.

    The schema is re-rendered with ``create_table_from_schema`` before parsing, so
    a schema without parsable columns yields no table name and only the generic
    check applies.
    """
    if is_unsupported_data_mutation_intent(message):
        return True
    table_name = parse_create_table_schema(create_table_from_schema(active_schema))[0]
    if not table_name:
        return False
    # Longest variants first so e.g. "categories" is tried before "category".
    variants = sorted(_table_name_variants(table_name), key=len, reverse=True)
    table_pattern = "|".join(map(re.escape, variants))
    table_mutation = (
        rf"\b(?:delete|drop|remove|deletar|excluir|exclua|remover|remova)\b"
        rf"\s+(?:the\s+|a\s+|an\s+)?(?:table\s+)?(?:{table_pattern})\b(?:\s+table)?"
    )
    return re.search(table_mutation, normalize_text(message)) is not None
|
|
|
|
def infer_column_type(column_name):
    """Guess a SQL type from a column name.

    Returns "INTEGER" for ids and counts, "NUMERIC" for money/measurement names,
    "DATE" for date-like names, and "TEXT" for everything else. Matching is
    case-insensitive and ignores surrounding whitespace.
    """
    name = column_name.strip().lower()
    integer_names = {"quantity", "quantidade", "stock", "estoque", "year"}
    numeric_names = {
        "salary", "price", "preco", "amount", "total", "grade", "peso", "weight",
        "idade", "age", "altura", "height", "largura", "width", "comprimento",
        "length", "desconto", "discount",
    }
    is_identifier = name == "id" or name.endswith("_id")
    if is_identifier or name in integer_names:
        return "INTEGER"
    if name in numeric_names:
        return "NUMERIC"
    if name.endswith("_date") or name in {"date", "created_at", "updated_at"}:
        return "DATE"
    return "TEXT"
|
|
|
|
def normalize_identifier(value):
    """Turn arbitrary text into a snake_case SQL identifier.

    Text is normalised with ``normalize_text``, and runs of non-word characters
    become "_" (trimmed at the ends). A leading digit gets a "col_" prefix.
    Returns "" when nothing usable remains.
    """
    slug = re.sub(r"\W+", "_", normalize_text(value)).strip("_")
    if slug and slug[0].isdigit():
        return f"col_{slug}"
    return slug
|
|
|
|
def parse_column_definition(raw_column):
    """Parse a "name [TYPE]" fragment into ``(column_name, column_type)``.

    Politeness words are removed first. The last SQL type keyword in the fragment
    wins and is canonicalised (INT -> INTEGER, BOOL -> BOOLEAN,
    DECIMAL -> NUMERIC, FLOAT/DOUBLE -> REAL). If nothing precedes that keyword,
    the keyword itself is the column name and the type is inferred. Words like
    "column"/"coluna" are dropped from the name. Returns None when no identifier
    remains.
    """
    text = re.sub(r"\b(for me|please|por favor)\b", "", raw_column or "", flags=re.IGNORECASE)
    text = text.strip(" .;:")
    if not text:
        return None

    canonical_types = {"INT": "INTEGER", "BOOL": "BOOLEAN", "DECIMAL": "NUMERIC", "FLOAT": "REAL", "DOUBLE": "REAL"}
    found_types = list(
        re.finditer(
            r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b",
            text,
            flags=re.IGNORECASE,
        )
    )
    name_part, column_type = text, None
    if found_types:
        last_type = found_types[-1]
        prefix = text[: last_type.start()].strip()
        if prefix:
            keyword = last_type.group(1).upper()
            name_part = prefix
            column_type = canonical_types.get(keyword, keyword)

    name_part = re.sub(r"\b(column|field|coluna|campo)\b", "", name_part, flags=re.IGNORECASE)
    column_name = normalize_identifier(name_part)
    if not column_name:
        return None
    return column_name, column_type or infer_column_type(column_name)
|
|
|
|
def split_column_list(columns_text):
    """Split a free-form or CREATE TABLE column list into per-column fragments.

    The text is split on commas; the conjunctions "and"/"e" are treated as commas.
    Each fragment is then handled as follows:

    * It contains a SQL type word and more than two non-filler tokens: walk the
      tokens and emit "name TYPE" pairs where a type word follows a name; other
      tokens are emitted alone.
    * It contains a type word with at most two non-filler tokens: the unfiltered
      fragment text is kept as-is.
    * It is made only of identifier-like tokens (and has more than one): each
      token becomes its own fragment.
    * Otherwise the fragment is kept unchanged.

    Returns the list of fragment strings, to be fed to ``parse_column_definition``.
    """
    # Treat the conjunctions "and" / "e" (Portuguese) as list separators.
    columns_text = re.sub(r"\s+(and|e)\s+", ",", columns_text or "", flags=re.IGNORECASE)
    parts = []
    type_pattern = r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b"
    type_tokens = {
        "integer", "int", "numeric", "decimal", "real", "float", "double",
        "text", "varchar", "char", "date", "datetime", "timestamp", "boolean", "bool",
    }
    # Filler words (English prepositions, Portuguese articles/contractions) dropped from token lists.
    stopwords = {"to", "from", "into", "as", "for", "o", "a", "os", "de", "do", "da", "dos", "das"}
    for part in (item.strip() for item in columns_text.split(",") if item.strip()):
        tokens = [token.strip() for token in re.split(r"\s+", part) if token.strip()]
        tokens = [token for token in tokens if token.lower() not in stopwords]
        if not tokens:
            continue
        if re.search(type_pattern, part, flags=re.IGNORECASE) and len(tokens) > 2:
            # Several "name TYPE" pairs written without commas, e.g. "name text age int".
            index = 0
            # Names that are themselves type-like; when followed by a date-like type
            # word they are emitted alone instead of being paired with it.
            inferrable_names = {"total", "date", "time", "timestamp", "int", "text", "real", "char"}
            while index < len(tokens):
                current = tokens[index]
                next_token = tokens[index + 1].lower() if index + 1 < len(tokens) else ""
                if next_token in type_tokens and not (
                    current.lower() in inferrable_names and next_token in {"date", "datetime", "timestamp"}
                ):
                    parts.append(f"{current} {tokens[index + 1]}")
                    index += 2
                else:
                    parts.append(current)
                    index += 1
            continue
        if re.search(type_pattern, part, flags=re.IGNORECASE):
            # Single "name TYPE" fragment: keep the original text, filler words included.
            parts.append(part)
            continue
        if len(tokens) > 1 and all(re.match(r"^[A-Za-z_][\wÀ-ÿ]*$", token) for token in tokens):
            # Space-separated bare names, e.g. "name salary department".
            parts.extend(tokens)
        else:
            parts.append(part)
    return parts
|
|
|
|
def format_create_table(table_name, columns):
    """Render ``(name, type)`` pairs as a multi-line ``CREATE TABLE`` statement.

    Duplicate column names keep their first occurrence. Returns "" when the table
    name or the column list is empty.
    """
    if not table_name or not columns:
        return ""
    unique_columns = {}
    for column_name, column_type in columns:
        unique_columns.setdefault(column_name, column_type)
    if not unique_columns:
        return ""
    body = ",\n".join(f"  {name} {kind}" for name, kind in unique_columns.items())
    return f"CREATE TABLE {table_name} (\n" + body + "\n);"
|
|
|
|
def create_table_from_message(message):
    """Build a ``CREATE TABLE`` statement from a request such as
    "create a table called products with name, price and stock".

    Two phrasings are tried in order: "table [called|named|...] NAME with COLS"
    and "<create verb> ... table NAME with COLS". The first match is used.
    Returns "" when neither matches.
    """
    text = (message or "").strip()
    column_clause = r"\s+(?:with|containing|including|com|tendo|contendo)\s+(.+)$"
    candidate_patterns = (
        r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)" + column_clause,
        rf"\b(?:{_CREATE_VERBS})\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)" + column_clause,
    )
    for pattern in candidate_patterns:
        found = re.search(pattern, text, flags=re.IGNORECASE)
        if found is None:
            continue
        parsed_columns = []
        for fragment in split_column_list(found.group(2)):
            parsed = parse_column_definition(fragment)
            if parsed:
                parsed_columns.append(parsed)
        return format_create_table(normalize_identifier(found.group(1)), parsed_columns)
    return ""
|
|
|
|
# SQL type words recognised in hand-written schemas (mirrors parse_column_definition).
_SCHEMA_TYPE_NAMES = r"integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool"

# Table-level constraints ("PRIMARY KEY (id)", "CONSTRAINT fk FOREIGN KEY (x) ...").
# They always carry a parenthesised column list and never declare a column themselves.
_TABLE_CONSTRAINT_RE = re.compile(
    r"^\s*(?:CONSTRAINT\s+[A-Za-z_]\w*\s+)?(?:PRIMARY\s+KEY|FOREIGN\s+KEY|UNIQUE|CHECK)\s*\(",
    re.IGNORECASE,
)

# Type arguments such as VARCHAR(255) or DECIMAL(10,2); dropped so only the type word remains.
_TYPE_ARGUMENTS_RE = re.compile(rf"\b({_SCHEMA_TYPE_NAMES})\s*\([^()]*\)", re.IGNORECASE)

# "name TYPE <column constraint ...>": only "name TYPE" is kept.
_COLUMN_CONSTRAINT_TAIL_RE = re.compile(
    rf"^\s*([A-Za-z_]\w*\s+(?:{_SCHEMA_TYPE_NAMES}))\s+"
    r"(?:PRIMARY\s+KEY|NOT\s+NULL|NULL|UNIQUE|DEFAULT|REFERENCES|CHECK|CONSTRAINT|"
    r"AUTOINCREMENT|AUTO_INCREMENT|COLLATE|GENERATED)\b",
    re.IGNORECASE,
)


def _split_top_level_definitions(body):
    """Split a CREATE TABLE body on commas outside parentheses and single-quoted strings."""
    definitions = []
    current = []
    depth = 0
    in_string = False
    for char in body:
        if char == "'":
            in_string = not in_string
        elif not in_string:
            if char == "(":
                depth += 1
            elif char == ")":
                depth = max(depth - 1, 0)
            elif char == "," and depth == 0:
                definitions.append("".join(current))
                current = []
                continue
        current.append(char)
    definitions.append("".join(current))
    return definitions


def parse_create_table_schema(schema):
    """Parse ``CREATE TABLE name (...)`` (or the shorthand ``name(...)``) into parts.

    Standard SQL column syntax is tolerated: ``IF NOT EXISTS``, type arguments
    such as ``DECIMAL(10,2)``, column constraints (``PRIMARY KEY``, ``NOT NULL``,
    ``DEFAULT ...``, ``REFERENCES ...``) and table-level constraints. None of these
    produce columns of their own. Definitions without such syntax go through
    ``split_column_list`` / ``parse_column_definition`` exactly as before.

    Args:
        schema: Schema text; ``None`` is treated as empty.

    Returns:
        ``(table_name, [(column_name, column_type), ...])``, or ``("", [])`` when
        the text does not look like a table definition.
    """
    schema = (schema or "").strip()
    match = re.match(
        r"^\s*(?:CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?)?([A-Za-z_][\w]*)\s*\((.*?)\)\s*;?\s*$",
        schema,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if not match:
        return "", []
    table_name = normalize_identifier(match.group(1))

    definitions = []
    for definition in _split_top_level_definitions(match.group(2)):
        if _TABLE_CONSTRAINT_RE.match(definition):
            continue
        definition = _TYPE_ARGUMENTS_RE.sub(r"\1", definition)
        constrained = _COLUMN_CONSTRAINT_TAIL_RE.match(definition)
        if constrained:
            definition = constrained.group(1)
        definitions.append(definition)

    columns = [
        parsed
        for parsed in (parse_column_definition(column) for column in split_column_list(",".join(definitions)))
        if parsed
    ]
    return table_name, columns
|
|
|
|
def _schema_column_names(schema):
    """Return the column names parsed from *schema*, in declaration order."""
    parsed_columns = parse_create_table_schema(schema)[1]
    return [column_name for column_name, _ in parsed_columns]
|
|
|
|
def _canonical_column(name):
    """Map a column name to its canonical English name from COLUMN_ALIASES.

    Names without a known alias are returned in their ``normalize_text`` form.
    """
    normalized = normalize_text(name)
    return next(
        (
            canonical
            for canonical, aliases in COLUMN_ALIASES.items()
            if normalized == canonical or normalized in aliases
        ),
        normalized,
    )
|
|
|
|
# Sentinel meaning "no default supplied" (so that None can be passed explicitly).
_NO_COLUMN_DEFAULT = object()


def _column_from_message(normalized_message, schema_columns, canonical, default=_NO_COLUMN_DEFAULT):
    """Resolve a canonical column mentioned in *normalized_message* to a schema column.

    * Not mentioned (no alias appears as a whole word): return *default*, or None
      when no default was given.
    * Mentioned and some schema column has that canonical name: return that column.
    * Mentioned but absent from the schema: return *default*, or *canonical*
      when no default was given.
    """
    has_default = default is not _NO_COLUMN_DEFAULT
    aliases = COLUMN_ALIASES.get(canonical, {canonical})
    mentioned = any(
        re.search(rf"(?<!\w){re.escape(alias)}(?!\w)", normalized_message) for alias in aliases
    )
    if not mentioned:
        return default if has_default else None
    for column in schema_columns:
        if _canonical_column(column) == canonical:
            return column
    return default if has_default else canonical
|
|
|
|
# Canonical numeric columns in the priority order used to pick a metric.
_NUMERIC_FOCUS_PRIORITY = ("price", "salary", "grade", "amount", "total", "quantity", "stock", "weight", "year")

# Canonical column names treated as numeric (same members as the priority tuple).
NUMERIC_COLUMN_CANONICALS = set(_NUMERIC_FOCUS_PRIORITY)


def _first_numeric_focus(normalized_message, schema_columns, default="value"):
    """Pick the numeric column a question is about.

    Prefers the first canonical numeric column (in priority order) that is both
    mentioned in *normalized_message* and present in *schema_columns*. Otherwise
    returns the first numeric-looking schema column, or *default* if none exists.
    """
    mentioned = (
        _column_from_message(normalized_message, schema_columns, canonical, None)
        for canonical in _NUMERIC_FOCUS_PRIORITY
    )
    first_mentioned = next((column for column in mentioned if column), None)
    if first_mentioned:
        return first_mentioned
    return next(
        (column for column in schema_columns if _canonical_column(column) in NUMERIC_COLUMN_CANONICALS),
        default,
    )
|
|
|
|
def _schema_column_for_canonical(schema_columns, canonical):
    """Return the first schema column whose canonical name is *canonical*, else ""."""
    matches = (column for column in schema_columns if _canonical_column(column) == canonical)
    return next(matches, "")
|
|
|
|
| def _schema_has_column(schema_columns, column_name): |
| return any(column == column_name for column in schema_columns) |
|
|
|
|
def _mentioned_schema_column(normalized_message, schema_columns):
    """Return the first schema column mentioned (as a whole word) in the message, else "".

    A column counts as mentioned when its normalised name, its canonical name, or
    any COLUMN_ALIASES spelling of that canonical name appears.
    """
    def _appears(alias):
        return re.search(rf"(?<!\w){re.escape(alias)}(?!\w)", normalized_message) is not None

    for column in schema_columns:
        canonical = _canonical_column(column)
        spellings = {normalize_text(column), canonical}
        spellings.update(COLUMN_ALIASES.get(canonical, set()))
        if any(_appears(alias) for alias in spellings):
            return column
    return ""
|
|
|
|
def _dimension_after_group_marker(normalized_message, schema_columns):
    """Return the grouping column named after "by/por/par/per", or "".

    A known alias resolves to the matching schema column, or to its canonical
    name when the schema lacks it. An unknown word is returned as a normalised
    identifier. Callers check schema membership themselves.
    """
    marker = re.search(r"\b(?:by|por|par|per)\s+([A-Za-z_][\w]*)", normalized_message)
    if marker is None:
        return ""
    word = normalize_text(marker.group(1))
    canonical = next(
        (name for name, aliases in COLUMN_ALIASES.items() if word == name or word in aliases),
        None,
    )
    if canonical is None:
        return normalize_identifier(word)
    return _column_from_message(word, schema_columns, canonical, canonical)
|
|
|
|
def _table_subject(schema):
    """Return a singular English noun for the schema's table (e.g. "product").

    Known tables map through TABLE_ALIASES. Other names longer than three
    characters lose a trailing "s". Falls back to "row" when there is no table name.
    """
    normalized = normalize_text(parse_create_table_schema(schema)[0])
    known = next(
        (canonical for canonical, aliases in TABLE_ALIASES.items() if normalized == canonical or normalized in aliases),
        None,
    )
    if known is not None:
        return known[:-1] if known.endswith("s") else known
    if len(normalized) > 3 and normalized.endswith("s"):
        return normalized[:-1]
    return normalized or "row"
|
|
|
|
| def _has_marker(normalized_message, markers): |
| return any(marker in normalized_message for marker in markers) |
|
|
|
|
| def _clean_english_question(text): |
| text = re.sub(r"\s+", " ", text or "").strip(" .;:") |
| if not text: |
| return LANGUAGE_FALLBACK_QUESTION |
| return text if text.endswith("?") else f"{text}?" |
|
|
|
|
def _lexical_translate_sql_question(normalized_message):
    """Translate a normalised foreign-language question into English word by word.

    First PHRASE_TRANSLATIONS are applied in order. The text is then split into
    words, numbers and comparison symbols. Each word is looked up in
    TOKEN_TRANSLATIONS, then COLUMN_ALIASES, then TABLE_ALIASES. Unknown words are
    kept as-is; all other characters (punctuation) are dropped.
    """
    text = normalized_message
    for pattern, replacement in PHRASE_TRANSLATIONS:
        text = re.sub(pattern, replacement, text)

    def _translate_word(word):
        key = normalize_text(word)
        translation = TOKEN_TRANSLATIONS.get(key)
        if translation:
            return translation
        for alias_table in (COLUMN_ALIASES, TABLE_ALIASES):
            for canonical, aliases in alias_table.items():
                if key == canonical or key in aliases:
                    return canonical
        return None

    english_words = []
    for word in re.findall(r"[A-Za-z_][\w]*|\d+(?:\.\d+)?|[<>=]+", text):
        translation = _translate_word(word)
        english_words.extend(translation.split() if translation else [word])
    return " ".join(english_words)
|
|
|
|
def normalize_sql_question_to_english(message, schema=""):
    """Rewrite a (possibly non-English) data question as a short English question.

    Known intents are templated with the table's subject noun, checked in order:
    most expensive, cheapest, average (optionally "by <dimension>"), and count
    (optionally grouped). Otherwise the message is translated lexically. The
    translation is used when the message contains a FOREIGN_SQL_MARKERS word;
    otherwise the stripped original message is returned. If the translation has
    no Latin letters at all, LANGUAGE_FALLBACK_QUESTION is returned.
    """
    original = (message or "").strip()
    normalized = normalize_text(original)
    if not normalized:
        return LANGUAGE_FALLBACK_QUESTION

    columns = _schema_column_names(schema)
    subject = _table_subject(schema)

    expensive_markers = ("mais caro", "mas caro", "plus cher", "maior preco", "highest price", "most expensive")
    cheap_markers = ("mais barato", "mas barato", "menor preco", "lowest price", "cheapest")
    average_markers = ("media", "moyenne", "promedio", "average")
    count_markers = ("quantos", "cuantos", "combien", "count", "how many")

    if _has_marker(normalized, expensive_markers):
        return _clean_english_question(f"what is the most expensive {subject}")
    if _has_marker(normalized, cheap_markers):
        return _clean_english_question(f"what is the cheapest {subject}")
    if _has_marker(normalized, average_markers):
        metric = _first_numeric_focus(normalized, columns)
        dimension = _dimension_after_group_marker(normalized, columns)
        grouping = f" by {dimension}" if dimension else ""
        return _clean_english_question(f"what is the average {metric}{grouping}")
    if _has_marker(normalized, count_markers):
        dimension = _dimension_after_group_marker(normalized, columns)
        if dimension:
            return _clean_english_question(f"count {subject}s by {dimension}")
        return _clean_english_question(f"how many {subject}s are there")

    translated = _lexical_translate_sql_question(normalized)
    if not re.search(r"[A-Za-z]", translated):
        return LANGUAGE_FALLBACK_QUESTION
    looks_foreign = any(
        re.search(rf"(?<!\w){re.escape(marker)}(?!\w)", normalized)
        for marker in FOREIGN_SQL_MARKERS
    )
    return _clean_english_question(translated) if looks_foreign else original
|
|
|
|
def deterministic_sql_query(message, schema=""):
    """Build SQL for common question shapes without calling a model.

    Supported shapes, in priority order:
    * most expensive / cheapest row (by the price or amount column);
    * average of a numeric column, optionally grouped "by <dimension>";
    * row count, optionally grouped (including the English
      "count <things> by <dimension>" phrasing);
    * comparison filters: greater/less than, and greater/less than or equal to;
    * highest/lowest value of a mentioned numeric column (not for comparisons
      like "maior que" / "maior ou igual a");
    * "list/show/select all [the] [rows|<table>]".

    Args:
        message: The user's question, in any supported language.
        schema: ``CREATE TABLE`` text for the active table.

    Returns:
        One SQL statement, or "" when the schema is unusable or no shape matches.
    """
    table_name, columns = parse_create_table_schema(schema)
    schema_columns = [name for name, _column_type in columns]
    if not table_name or not schema_columns:
        return ""

    normalized_message = normalize_text(message)
    english_question = normalize_sql_question_to_english(message, schema)
    normalized_english = normalize_text(english_question)

    if _has_marker(normalized_message, ("mais caro", "mas caro", "plus cher", "maior preco")) or _has_marker(
        normalized_english, ("most expensive", "highest price")
    ):
        metric = _schema_column_for_canonical(schema_columns, "price") or _schema_column_for_canonical(schema_columns, "amount")
        if metric and _schema_has_column(schema_columns, metric):
            return f"SELECT * FROM {table_name} ORDER BY {metric} DESC LIMIT 1;"

    if _has_marker(normalized_message, ("mais barato", "mas barato", "menor preco")) or _has_marker(
        normalized_english, ("cheapest", "lowest price")
    ):
        metric = _schema_column_for_canonical(schema_columns, "price") or _schema_column_for_canonical(schema_columns, "amount")
        if metric and _schema_has_column(schema_columns, metric):
            return f"SELECT * FROM {table_name} ORDER BY {metric} ASC LIMIT 1;"

    if _has_marker(normalized_message, ("media", "moyenne", "promedio")) or "average" in normalized_english:
        metric = _first_numeric_focus(normalized_message, schema_columns, "")
        dimension = _dimension_after_group_marker(normalized_message, schema_columns)
        if metric and dimension and _schema_has_column(schema_columns, metric) and _schema_has_column(schema_columns, dimension):
            return (
                f"SELECT {dimension}, AVG({metric}) AS average_{metric} "
                f"FROM {table_name} GROUP BY {dimension};"
            )
        if metric and _schema_has_column(schema_columns, metric):
            return f"SELECT AVG({metric}) AS average_{metric} FROM {table_name};"

    # normalize_sql_question_to_english phrases a grouped count as
    # "count <subject>s by <dimension>" (no "how many"), so that prefix is accepted
    # too; otherwise English "how many X per Y" never reached this branch.
    if (
        _has_marker(normalized_message, ("quantos", "cuantos", "combien", "how many"))
        or "how many" in normalized_english
        or normalized_english.startswith("count ")
    ):
        dimension = _dimension_after_group_marker(normalized_message, schema_columns)
        if dimension and _schema_has_column(schema_columns, dimension):
            return f"SELECT {dimension}, COUNT(*) AS row_count FROM {table_name} GROUP BY {dimension};"
        return f"SELECT COUNT(*) AS row_count FROM {table_name};"

    # PHRASE_TRANSLATIONS turns "maior/menor ou igual a" into "greater/less than or
    # equal to"; the longer phrases are listed first so they win over "greater than".
    comparison_operators = {
        "greater than or equal to": ">=",
        "less than or equal to": "<=",
        "greater than": ">",
        "less than": "<",
    }
    comparison_match = re.search(
        r"\b([A-Za-z_][\w]*)\s+(greater than or equal to|less than or equal to|greater than|less than)"
        r"\s+(\d+(?:\.\d+)?)\b",
        normalized_english,
    )
    if comparison_match:
        column = _canonical_column(comparison_match.group(1))
        actual_column = _schema_column_for_canonical(schema_columns, column)
        if actual_column:
            operator = comparison_operators[comparison_match.group(2)]
            return f"SELECT * FROM {table_name} WHERE {actual_column} {operator} {comparison_match.group(3)};"

    ranking_column = _mentioned_schema_column(normalized_message, schema_columns)
    if ranking_column and _canonical_column(ranking_column) in NUMERIC_COLUMN_CANONICALS:
        # "maior que" / "maior ou igual a" are comparisons, not "the highest" requests.
        wants_highest = _has_marker(normalized_message, ("maior", "maximo", "mais alto")) or _has_marker(
            normalized_english, ("highest", "maximum", "max")
        )
        if wants_highest and not any(phrase in normalized_message for phrase in ("maior que", "maior ou igual")):
            return f"SELECT * FROM {table_name} ORDER BY {ranking_column} DESC LIMIT 1;"
        wants_lowest = _has_marker(normalized_message, ("menor", "minimo", "mais baixo")) or _has_marker(
            normalized_english, ("lowest", "minimum", "min")
        )
        if wants_lowest and not any(phrase in normalized_message for phrase in ("menor que", "menor ou igual")):
            return f"SELECT * FROM {table_name} ORDER BY {ranking_column} ASC LIMIT 1;"

    # Translated questions end with "?" (added by _clean_english_question), so trailing
    # punctuation is removed before matching the "list all" family.
    request = normalized_english.rstrip("?. ")
    list_all_pattern = rf"(?:list|show|select) all(?: the)?(?: rows| {re.escape(table_name)})?"
    if re.fullmatch(list_all_pattern, request):
        return f"SELECT * FROM {table_name};"

    return ""
|
|
|
|
def create_table_from_schema(schema):
    """Re-render *schema* as a normalized ``CREATE TABLE`` statement.

    The schema is parsed with ``parse_create_table_schema`` into a table name
    and a column list, which are then handed unchanged to
    ``format_create_table``.
    """
    parsed_table, parsed_columns = parse_create_table_schema(schema)
    return format_create_table(parsed_table, parsed_columns)
|
|
|
|
def _create_table_paren_end(text, open_index):
    """Return the index just past the ``)`` matching ``text[open_index]``, or -1 if unbalanced."""
    depth = 0
    for index in range(open_index, len(text)):
        char = text[index]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return index + 1
    return -1


def extract_create_table_statement(text):
    """Return the first ``CREATE TABLE`` statement found in *text*, or ``""``.

    *text* is first reduced with ``extract_sql_candidate``. The column list is
    delimited by balancing parentheses, so column types that carry their own
    arguments (``VARCHAR(255)``, ``DECIMAL(10,2)``) do not end the statement
    early. Optional trailing whitespace and a ``;`` are kept, and the result is
    passed through ``clean_generation``.

    Parentheses inside quoted string literals are not treated specially. When
    the parentheses after the table name never balance, the first ``)`` is
    used as the end of the column list (the plain non-greedy regex match).
    """
    cleaned = extract_sql_candidate(text)
    header = re.search(
        r"\bCREATE\s+TABLE\s+[A-Za-z_][\w]*\s*\(",
        cleaned,
        flags=re.IGNORECASE,
    )
    if header:
        # header.end() - 1 is the index of the opening "(" of the column list.
        end = _create_table_paren_end(cleaned, header.end() - 1)
        if end != -1:
            # Include optional whitespace and a terminating semicolon.
            end += re.match(r"\s*;?", cleaned[end:]).end()
            return clean_generation(cleaned[header.start():end])
    # Fallback for unbalanced input: stop at the first closing parenthesis.
    match = re.search(
        r"\bCREATE\s+TABLE\s+[A-Za-z_][\w]*\s*\(.*?\)\s*;?",
        cleaned,
        flags=re.IGNORECASE | re.DOTALL,
    )
    return clean_generation(match.group(0)) if match else ""
|
|
|
|
def last_create_table_from_history(chat_history):
    """Find the most recent ``CREATE TABLE`` statement an assistant message produced.

    Entries are scanned newest-first. Only dict entries whose ``"role"`` is
    ``"assistant"`` are considered, and their ``"content"`` (default ``""``)
    is fed to ``extract_create_table_statement``. The first non-empty
    statement wins. ``None`` history, or no match, yields ``""``.
    """
    assistant_contents = (
        entry.get("content", "")
        for entry in reversed(list(chat_history or []))
        if isinstance(entry, dict) and entry.get("role") == "assistant"
    )
    for content in assistant_contents:
        found = extract_create_table_statement(content)
        if found:
            return found
    return ""
|
|
|
|
def extract_added_columns(message):
    """Parse the column definitions a user asks to add in a free-text request.

    Two phrasings are tried in order; the first that yields at least one
    parsed column wins:

    1. Everything after a colon, e.g. ``"add these: email TEXT, age INTEGER"``.
    2. Everything after an add-style verb (English or Portuguese). The verb may
       be followed by an optional article, an optional "new"/"novo"/"nova",
       and an optional column-like noun, e.g. ``"adicione uma nova coluna email"``.

    The column-like noun may be singular or plural and must end on a word
    boundary. This keeps identifiers such as ``item_count`` or ``field_name``
    whole, and keeps ``"add columns a, b"`` from capturing a stray ``s``.

    Args:
        message: Raw user text; ``None`` is treated as empty.

    Returns:
        The truthy results of ``parse_column_definition`` for each item that
        ``split_column_list`` extracts from the first matching pattern, or
        ``[]`` when nothing parses.
    """
    message = (message or "").strip()
    patterns = (
        r":\s*(.+)$",
        r"\b(?:add|include|with|adicionar|adicione|adicionando|inclua|incluir|acrescente|ter|coloque|colocar)\b"
        r"\s+(?:um\s+|uma\s+|a\s+|an\s+)?(?:novo\s+|nova\s+|new\s+)?"
        # Column-like noun: plural allowed, and it must end at a word boundary
        # so it is never carved out of a longer identifier.
        r"(?:(?:columns?|fields?|elements?|colunas?|campos?|elementos?|items?)\b)?"
        r"\s*(.+)$",
    )
    for pattern in patterns:
        match = re.search(pattern, message, flags=re.IGNORECASE)
        if not match:
            continue
        columns = [
            parsed
            for parsed in (parse_column_definition(column) for column in split_column_list(match.group(1)))
            if parsed
        ]
        if columns:
            return columns
    return []
|
|
|
|
def extract_removed_columns(message):
    """Return the normalized names of the columns a user asks to remove.

    Matches a removal verb (English or Portuguese), an optional article
    (``a``/``o``/``the``), and an optional column-like noun, then captures
    the rest of the message. The capture is split with ``split_column_list``
    and each item is passed through ``normalize_identifier``; empty results
    are dropped.

    The column-like noun may be singular or plural and must end on a word
    boundary. ``"remove field_name"`` therefore yields ``field_name`` rather
    than ``_name``, and ``"remove columns a, b"`` carries no stray ``s``.

    Args:
        message: Raw user text; ``None`` is treated as empty.

    Returns:
        List of normalized column names, or ``[]`` if no removal was found.
    """
    message = (message or "").strip()
    patterns = (
        r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|"
        r"retira|retirar|retire|apaga|apagar|apague)\b"
        r"\s+(?:a\s+|o\s+|the\s+)?"
        # Column-like noun: plural allowed, and it must end at a word boundary.
        r"(?:(?:columns?|fields?|elements?|colunas?|campos?|elementos?|items?)\b)?"
        r"\s*(.+)$",
    )
    for pattern in patterns:
        match = re.search(pattern, message, flags=re.IGNORECASE)
        if not match:
            continue
        columns = [normalize_identifier(column) for column in split_column_list(match.group(1))]
        columns = [column for column in columns if column]
        if columns:
            return columns
    return []
|
|
|
|
def extract_renamed_columns(message):
    """Collect ``(old_name, new_name)`` rename pairs from a free-text request.

    Two phrasings are recognized:
    - a rename-style verb followed by ``<old> to|para|as|como|por <new>``;
    - ``troca <old> por <new>``.

    Both names are passed through ``normalize_identifier``. A pair is dropped
    if either name normalizes to empty, if the old name is a pronoun, or if
    the new name is one of a few filler words ("ter", "have", ...).
    ``None`` is treated as empty.
    """
    text = message or ""
    verb_rename = (
        r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|alterar|altera|mude|mudar|muda|edita|editar)\s+"
        r"(\w+)\s+(?:to|para|as|como|por)\s+(\w+)"
    )
    candidates = re.findall(verb_rename, text, flags=re.IGNORECASE)
    candidates += re.findall(r"\btroca\b\s+(\w+)\s+\bpor\b\s+(\w+)", text, flags=re.IGNORECASE)

    # Pronouns are not column names ("rename it to x").
    rejected_sources = {"ela", "ele", "isso", "isto", "essa", "esse", "this", "it"}
    # Filler words that follow "como"/"as" in non-rename sentences.
    rejected_targets = {"ter", "have", "having", "tambem"}

    pairs = []
    for raw_source, raw_target in candidates:
        source = normalize_identifier(raw_source)
        if not source:
            continue
        target = normalize_identifier(raw_target)
        if not target or source in rejected_sources or target in rejected_targets:
            continue
        pairs.append((source, target))
    return pairs
|
|
|
|
def parse_compound_edit(message):
    """Split a multi-action edit request and route each part to its parser.

    The message is split on ``and``/``e`` whenever the next word is an edit
    verb. Each non-empty segment is then classified:
    - rename, if ``is_rename_intent`` accepts it (``extract_renamed_columns``);
    - removal, if it contains a removal verb (``extract_removed_columns``);
    - otherwise addition (``extract_added_columns``).

    The removal verbs checked here are the same set that
    ``extract_removed_columns`` itself accepts. Segments such as
    ``"tire telefone"`` are therefore treated as removals instead of being
    handed to the add parser and lost.

    Args:
        message: Raw user text; ``None`` is treated as empty.

    Returns:
        ``(added, removed, renamed)`` lists aggregated across all segments.
    """
    segment_pattern = (
        r"\s+(?:and|e)\s+"
        r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|"
        r"adicione|adicionar|adicionando|inclua|incluir|acrescente|"
        r"coloca|coloque|colocar|bota|insira|insere|"
        r"remova|remover|deletar|exclua|excluir|exclui|"
        r"tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague|"
        r"renomeie|renomear|renomeia|altere|alterar|altera|"
        r"mude|mudar|muda|edita|editar|troca|trocar)\b)"
    )
    # Must stay in sync with the verbs accepted by extract_removed_columns.
    removal_pattern = (
        r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|"
        r"tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b"
    )
    segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE)
    added, removed, renamed = [], [], []
    for seg in segments:
        seg = seg.strip()
        if not seg:
            continue
        if is_rename_intent(seg):
            renamed.extend(extract_renamed_columns(seg))
        elif re.search(removal_pattern, seg, flags=re.IGNORECASE):
            removed.extend(extract_removed_columns(seg))
        else:
            added.extend(extract_added_columns(seg))
    return added, removed, renamed
|
|
|
|
def edit_create_table_from_message(message, chat_history, active_schema):
    """Apply add/remove/rename requests in *message* to the current table definition.

    The base table is the last ``CREATE TABLE`` produced by the assistant in
    *chat_history*, or else the one derived from *active_schema*.

    Steps:
    - Removals (from the whole message and from each compound segment) are
      dropped.
    - Renames are applied to the surviving columns.
    - Requested additions are appended. An addition whose name already exists
      (case-insensitive) or repeats an earlier addition is skipped, because
      duplicate column names make the ``CREATE TABLE`` invalid.

    Returns:
        The re-formatted ``CREATE TABLE`` statement, or ``""`` in any of these
        cases: the message is not an edit/rename request, no base table can
        be parsed, no change was requested, or the column list ends up
        unchanged.
    """
    if not is_table_edit_intent(message) and not is_rename_intent(message):
        return ""
    base_sql = last_create_table_from_history(chat_history) or create_table_from_schema(active_schema)
    table_name, existing_columns = parse_create_table_schema(base_sql)
    if not table_name:
        return ""
    added_columns, compound_removed, renamed_columns = parse_compound_edit(message)
    removed_set = set(extract_removed_columns(message)) | set(compound_removed)
    if not added_columns and not removed_set and not renamed_columns:
        return ""
    rename_map = dict(renamed_columns)
    kept_columns = [
        (rename_map.get(col_name, col_name), col_type)
        for col_name, col_type in existing_columns
        if col_name not in removed_set
    ]
    # SQL column names are case-insensitive, so compare lowered names.
    # Assumes each added column is a (name, type) pair like kept_columns.
    seen_names = {str(col_name).lower() for col_name, _ in kept_columns}
    new_columns = []
    for column in added_columns:
        key = str(column[0]).lower()
        if key in seen_names:
            continue
        seen_names.add(key)
        new_columns.append(column)
    updated_columns = [*kept_columns, *new_columns]
    if updated_columns == existing_columns:
        return ""
    return format_create_table(table_name, updated_columns)
|
|
|
|
def create_table_from_suggestion(suggestion):
    """Build a ``CREATE TABLE`` statement from a table suggestion.

    *suggestion* may be one of:
    - a dict with ``"table_name"`` and a ``"columns"`` list of dicts, each
      with ``"name"`` and optional ``"type"``; non-dict entries are ignored;
    - any object exposing ``table_name`` and ``columns`` attributes, where
      ``columns`` is an iterable of ``(name, type)`` pairs.

    Column names go through ``normalize_identifier``; columns whose name
    normalizes to empty are skipped. Types default to ``"TEXT"`` and are
    upper-cased. A falsy *suggestion* yields ``""``.
    """
    if not suggestion:
        return ""

    if isinstance(suggestion, dict):
        raw_table = suggestion.get("table_name")
        raw_columns = [
            (entry.get("name"), entry.get("type", "TEXT"))
            for entry in suggestion.get("columns", [])
            if isinstance(entry, dict)
        ]
    else:
        raw_table = getattr(suggestion, "table_name", "")
        raw_columns = getattr(suggestion, "columns", ())

    columns = []
    for raw_name, raw_type in raw_columns:
        column_name = normalize_identifier(raw_name)
        if not column_name:
            continue
        columns.append((column_name, (raw_type or "TEXT").upper()))
    return format_create_table(normalize_identifier(raw_table), columns)
|
|