import html import re import unicodedata import sqlparse SQL_STARTERS = {"SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP"} LANGUAGE_FALLBACK_QUESTION = "write a SELECT query for this table." FOREIGN_SQL_MARKERS = { "alumnos", "alunos", "barato", "caro", "clientes", "combien", "conte", "contar", "cuantos", "departamento", "empleados", "estoque", "funcionarios", "liste", "listar", "maior", "mas", "media", "menor", "mostre", "mostrar", "moyenne", "ordenado", "ordene", "par", "por", "preco", "precio", "produit", "produits", "producto", "productos", "produto", "produtos", "qual", "quantos", "salaire", "salario", "soma", "vendas", } TABLE_ALIASES = { "employees": {"employee", "employees", "funcionario", "funcionarios", "empleado", "empleados"}, "orders": {"order", "orders", "pedido", "pedidos"}, "students": {"student", "students", "aluno", "alunos", "alumno", "alumnos", "etudiant", "etudiants"}, "products": {"product", "products", "produto", "produtos", "producto", "productos", "produit", "produits"}, "sales": {"sale", "sales", "venda", "vendas", "venta", "ventas"}, "customers": {"customer", "customers", "cliente", "clientes"}, } COLUMN_ALIASES = { "amount": {"amount", "valor", "importe", "montant"}, "category": {"category", "categoria", "categorie"}, "course": {"course", "curso"}, "customer_id": {"customer_id", "cliente_id", "id_cliente"}, "date": {"date", "data", "fecha"}, "department": {"department", "departamento", "departement"}, "grade": {"grade", "nota", "calificacion"}, "name": {"name", "nome", "nombre", "nom"}, "price": {"price", "preco", "precio", "prix"}, "product": {"product", "produto", "producto", "produit"}, "product_id": {"product_id", "produto_id", "producto_id", "id_produto"}, "quantity": {"quantity", "quantidade", "cantidad", "quantite"}, "salary": {"salary", "salario", "salaire", "sueldo"}, "stock": {"stock", "estoque", "inventario"}, "total": {"total", "soma", "sum"}, "weight": {"weight", "peso", "poids"}, "year": {"year", "ano", "anio", "annee"}, } PHRASE_TRANSLATIONS = ( (r"\bmaior que\b", "greater than"), (r"\bmenor que\b", "less than"), (r"\bmaior ou igual a\b", "greater than or equal to"), (r"\bmenor ou igual a\b", "less than or equal to"), (r"\bmais caro\b", "most expensive"), (r"\bmas caro\b", "most expensive"), (r"\bplus cher\b", "most expensive"), (r"\bpreco mais alto\b", "highest price"), (r"\bprecio mas alto\b", "highest price"), (r"\bmaior preco\b", "highest price"), (r"\bmais barato\b", "cheapest"), (r"\bmas barato\b", "cheapest"), (r"\bplus bas prix\b", "cheapest"), (r"\bmenor preco\b", "lowest price"), (r"\bmaior para menor\b", "descending"), (r"\bmenor para maior\b", "ascending"), ) TOKEN_TRANSLATIONS = { "a": "the", "agrupe": "group", "alunos": "students", "barato": "cheap", "caro": "expensive", "com": "with", "conte": "count", "contar": "count", "cuantos": "how many", "da": "of", "das": "of", "de": "of", "do": "of", "dos": "of", "e": "and", "em": "in", "filtre": "filter", "filtrar": "filter", "funcionarios": "employees", "liste": "list", "listar": "list", "maior": "highest", "mas": "more", "media": "average", "menor": "lowest", "mostre": "show", "mostrar": "show", "moyenne": "average", "o": "the", "os": "the", "ordene": "order", "ordenado": "ordered", "par": "by", "para": "for", "por": "by", "produto": "product", "produtos": "products", "peso": "weight", "qual": "what", "quantos": "how many", "soma": "sum", "some": "sum", "todos": "all", } def content_to_text(value): if value is None: return "" if isinstance(value, str): return value if isinstance(value, dict): for key in ("text", "content", "value"): if key in value: return content_to_text(value[key]) return " ".join(content_to_text(item) for item in value.values()) if isinstance(value, (list, tuple)): return "\n".join(content_to_text(item) for item in value) return str(value) def normalize_text(value): text = content_to_text(value).lower() text = unicodedata.normalize("NFKD", text) text = "".join(char for char in text if not unicodedata.combining(char)) return re.sub(r"\s+", " ", text).strip() def clean_generation(text): cleaned = content_to_text(text).strip() if cleaned.startswith("```"): lines = cleaned.splitlines() if lines and lines[0].strip().lower() in {"```", "```sql"}: lines = lines[1:] if lines and lines[-1].strip() == "```": lines = lines[:-1] cleaned = "\n".join(lines).strip() for marker in ("<|end|>", "<|user|>", "<|assistant|>", ""): if marker in cleaned: cleaned = cleaned.split(marker, 1)[0].strip() if cleaned.upper().startswith("SQL:"): cleaned = cleaned[4:].strip() return cleaned def extract_sql_candidate(text): cleaned = clean_generation(text) match = re.search(r"\b(SELECT|WITH|INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b", cleaned, flags=re.IGNORECASE) if not match: return cleaned return cleaned[match.start() :].strip() def is_sql_like(text): text = (text or "").strip() if not text: return False first_word = re.match(r"^\s*([A-Za-z]+)", text) if not first_word: return False return first_word.group(1).upper() in SQL_STARTERS def is_sql_intent(message, schema=""): message = normalize_text(message) if not message: return False smalltalk_patterns = { "oi", "ola", "olá", "hi", "hello", "hey", "obrigado", "obrigada", "thanks", "thank you", "como voce esta", "como você esta", "qual seu nome", "me conte uma piada", "vamos conversar", "como voce funciona", "como funciona", "o que voce faz", "o que faz", } if message in {normalize_text(item) for item in smalltalk_patterns}: return False if any(pattern in message for pattern in ("como voce esta", "qual seu nome", "conte uma piada")): return False sql_terms = { "all", "average", "count", "columns", "database", "find", "get", "group by", "highest", "join", "least", "list", "lowest", "max", "maximum", "min", "minimum", "most", "most expensive", "order by", "query", "rows", "select", "show", "sum", "where", "consulta", "consultar", "contar", "colunas", "linhas", "liste", "listar", "maior", "mais caro", "menor", "media", "mostre", "mostrar", "ordene", "selecione", "some", "soma", "quantos", "filtre", "filtrar", } if any(re.search(rf"(?|!=|<=|>=|<|>)\s*\"(?:\"\"|[^\"])*\"", r"\1 ", sql_text or "", ) return re.sub( r"'(?:''|[^'])*'|\b\d+(?:\.\d+)?\b", " ", without_quoted_values, ) def _identifier_names(sql_text): try: statements = [stmt for stmt in sqlparse.parse(sql_text) if str(stmt).strip()] except Exception: return [] names = [] for statement in statements: flattened = list(statement.flatten()) for index, token in enumerate(flattened): if token.ttype not in sqlparse.tokens.Name and token.ttype is not sqlparse.tokens.Literal.String.Symbol: continue previous_value = "" next_value = "" for previous in reversed(flattened[:index]): if not previous.is_whitespace: previous_value = previous.value break for next_token in flattened[index + 1:]: if not next_token.is_whitespace: next_value = next_token.value break names.append((token.value.strip('"'), previous_value, next_value)) return names def sql_schema_validation_issue(sql_text, schema): table_name, columns = parse_create_table_schema(schema) if not table_name or not columns: return "" expected_table = table_name.lower() allowed_columns = {name.lower() for name, _column_type in columns} scrubbed_sql = _without_sql_literals(sql_text) cte_aliases = { match.group(1).lower() for match in re.finditer( r"(?:WITH|,)\s+([A-Za-z_][\w]*)\s+AS\s*\(", scrubbed_sql, flags=re.IGNORECASE, ) } table_refs = [ match.group(1).lower() for match in re.finditer( r"\b(?:FROM|JOIN)\s+([A-Za-z_][\w]*)", scrubbed_sql, flags=re.IGNORECASE, ) ] if not table_refs: return f"Model SQL does not reference active table: {table_name}" for table_ref in table_refs: if table_ref not in {expected_table, *cte_aliases}: return f"Unknown table for active schema: {table_ref}" table_aliases = set() table_alias_pattern = ( r"\b(?:FROM|JOIN)\s+" rf"({re.escape(table_name)})" r"\s+(?:AS\s+)?([A-Za-z_][\w]*)" ) for match in re.finditer(table_alias_pattern, scrubbed_sql, flags=re.IGNORECASE): alias = match.group(2) if alias.upper() not in SQL_ALIAS_STOPWORDS: table_aliases.add(alias.lower()) identifier_sql = _without_sql_value_literals(sql_text) output_aliases = { (match.group(1) or match.group(2)).lower() for match in re.finditer( r"\bAS\s+(?:\"([^\"]+)\"|([A-Za-z_][\w]*))", identifier_sql, flags=re.IGNORECASE, ) } allowed_non_columns = { expected_table, *cte_aliases, *table_aliases, *output_aliases, } for name, _previous_value, next_value in _identifier_names(identifier_sql): normalized_name = name.lower() if next_value == ".": continue if normalized_name in allowed_non_columns: continue if name.upper() in SQL_SCHEMA_FUNCTION_NAMES: continue if normalized_name in allowed_columns: continue return f"Unknown column for active schema: {name}" return "" def validate_sql(sql_text, schema=""): sql_text = (sql_text or "").strip() if not sql_text: return 'No SQL yet' try: statements = [stmt for stmt in sqlparse.parse(sql_text) if str(stmt).strip()] except Exception as exc: error_type = html.escape(type(exc).__name__) return ( 'Check syntax' f'sqlparse error: {error_type}' ) if not statements: return ( 'Check syntax' 'No parsed SQL statement.' ) first_token = statements[0].token_first(skip_cm=True) token_value = first_token.value.strip().upper() if first_token is not None else "UNKNOWN" if token_value not in SQL_STARTERS: escaped_token = html.escape(token_value) return ( 'Check syntax' f'First token: {escaped_token}' ) trailing_keyword = re.search( r"\b(AND|BY|FROM|GROUP|HAVING|JOIN|LIMIT|NOT|ON|OR|ORDER|SELECT|WHERE)\s*;?\s*$", sql_text, flags=re.IGNORECASE, ) if trailing_keyword: escaped_token = html.escape(trailing_keyword.group(1).upper()) return ( 'Check syntax' f'Incomplete trailing clause: {escaped_token}' ) bare_negated_predicate = re.search( r"\b(AND|HAVING|OR|WHERE)\s+NOT\s+([A-Za-z_][\w.]*)\s*;?\s*$", sql_text, flags=re.IGNORECASE, ) if bare_negated_predicate: escaped_token = html.escape(bare_negated_predicate.group(2)) return ( 'Check syntax' f'Incomplete negated predicate: NOT {escaped_token}' ) trailing_comparison = re.search(r"(=|<>|!=|<=|>=|<|>)\s*;?\s*$", sql_text) if trailing_comparison: escaped_token = html.escape(trailing_comparison.group(1)) return ( 'Check syntax' f'Incomplete comparison operator: {escaped_token}' ) schema_issue = sql_schema_validation_issue(sql_text, schema) if schema_issue: escaped_issue = html.escape(schema_issue) return ( 'Check schema' f'{escaped_issue}' ) return 'Valid SQL' # Verb stems for create-table intent — keep in sync with create_table_from_message below. _CREATE_VERBS = ( r"create|make|build|generate" r"|criar|crie|cria|criando" r"|gerar|gere|gera|gerando" r"|faz|faca|fa\u00e7a|fazendo" r"|monta|montar|monte" r"|construa|construir|constroi" r"|elabore|elaborar|elabora" ) def is_create_table_intent(message): message = (message or "").strip().lower() return bool( re.search(rf"\b({_CREATE_VERBS})\b", message) and re.search(r"\b(table|schema|tabela)\b", message) ) def is_rename_intent(message): return bool(extract_renamed_columns(message)) def is_table_edit_intent(message): message = (message or "").strip().lower() edit_terms = r"\b(edit|update|modify|alter|add|include|remove|delete|drop|edita|editar|altera|altere|alterar|mude|mudar|adicione|adicionar|inclua|incluir|acrescente|remova|remover|deletar|exclua|excluir|exclui|novo|nova|troca|trocar|coloque|colocar|coloca|insira|insere|bota|tira|retire|retira|apaga|apague)\b" direct_add_terms = r"\b(add|include|adicione|adicionar|adicionando|inclua|incluir|acrescente|coloque|colocar|coloca|acrescenta|insere|inserir|insira|bota|botar|bote)\b" direct_remove_terms = r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b" target_terms = r"\b(column|field|element|coluna|campo|elemento|item)\b" sql_aggregation_terms = {"up", "sum", "total", "count", "average", "avg", "max", "min", "by", "soma", "media", "contagem", "maximo", "minimo"} add_match = re.search(direct_add_terms, message) if add_match: after_add = message[add_match.start() + len(add_match.group()) :].strip() first_word_after = after_add.split()[0] if after_add.split() else "" is_add_intent = first_word_after not in sql_aggregation_terms else: is_add_intent = False return bool( is_add_intent or re.search(direct_remove_terms, message) or is_rename_intent(message) or re.search(r"\b(?:altere|alterar|mude|mudar)\b.*\bter\b", message) or (re.search(edit_terms, message) and (re.search(target_terms, message) or ":" in message or re.search(r"\bpor\b", message))) ) def is_unsupported_data_mutation_intent(message): message = normalize_text(message) if not message: return False if is_create_table_intent(message): return False explicit_dml = ( r"\b(?:insert\s+into|update\b.+\bset\b|delete\s+from|" r"drop\s+table|drop\s+(?:the\s+|a\s+|an\s+)?\w+\s+table|truncate\s+table)\b" ) if re.search(explicit_dml, message): return True if re.search(r"\binsert\b", message) and not re.search(r"\b(column|field|coluna|campo)\b", message): return True row_terms = r"(?:row|rows|record|records|linha|linhas|registro|registros)" row_mutation = ( rf"\b(?:add|create|insert|include|remove|delete|drop|" rf"adicionar|adicione|criar|crie|incluir|inclua|remover|remova|deletar|excluir|exclua|apagar|apague)\b" rf".*\b{row_terms}\b" ) if re.search(row_mutation, message): return True destructive_all = ( r"\b(?:delete|drop|remove|deletar|exclua|excluir|apaga|apagar|apague|remova|remover)\b" r"\s+(?:all|every|rows?|records?|table|todos|todas|linhas?|registros?|tabela)\b" ) return bool(re.search(destructive_all, message)) def _table_name_variants(table_name): base = normalize_text(table_name) if not base: return set() variants = {base} if base.endswith("ies") and len(base) > 3: variants.add(f"{base[:-3]}y") elif base.endswith("s") and len(base) > 1: variants.add(base[:-1]) else: variants.add(f"{base}s") return {variant for variant in variants if variant} def is_unsupported_data_mutation_for_schema(message, active_schema=""): if is_unsupported_data_mutation_intent(message): return True table_name, _columns = parse_create_table_schema(create_table_from_schema(active_schema)) if not table_name: return False message = normalize_text(message) table_pattern = "|".join( re.escape(variant) for variant in sorted(_table_name_variants(table_name), key=len, reverse=True) ) table_mutation = ( rf"\b(?:delete|drop|remove|deletar|excluir|exclua|remover|remova)\b" rf"\s+(?:the\s+|a\s+|an\s+)?(?:table\s+)?(?:{table_pattern})\b(?:\s+table)?" ) return bool(re.search(table_mutation, message)) def infer_column_type(column_name): name = column_name.strip().lower() if name == "id" or name.endswith("_id") or name in {"quantity", "quantidade", "stock", "estoque", "year"}: return "INTEGER" if name in { "salary", "price", "preco", "amount", "total", "grade", "peso", "weight", "idade", "age", "altura", "height", "largura", "width", "comprimento", "length", "desconto", "discount", }: return "NUMERIC" if name in {"date", "created_at", "updated_at"} or name.endswith("_date"): return "DATE" return "TEXT" def normalize_identifier(value): identifier = re.sub(r"\W+", "_", normalize_text(value)).strip("_") if not identifier: return "" if identifier[0].isdigit(): identifier = f"col_{identifier}" return identifier def parse_column_definition(raw_column): raw_column = re.sub(r"\b(for me|please|por favor)\b", "", raw_column or "", flags=re.IGNORECASE) raw_column = raw_column.strip(" .;:") if not raw_column: return None type_matches = list( re.finditer( r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b", raw_column, flags=re.IGNORECASE, ) ) explicit_type = type_matches[-1] if type_matches else None if explicit_type: name_part = raw_column[: explicit_type.start()].strip() column_type = explicit_type.group(1).upper() if column_type == "INT": column_type = "INTEGER" elif column_type == "BOOL": column_type = "BOOLEAN" elif column_type == "DECIMAL": column_type = "NUMERIC" elif column_type in {"FLOAT", "DOUBLE"}: column_type = "REAL" if not name_part.strip(): column_type = None name_part = raw_column else: name_part = raw_column column_type = None name_part = re.sub(r"\b(column|field|coluna|campo)\b", "", name_part, flags=re.IGNORECASE) column_name = normalize_identifier(name_part) if not column_name: return None return column_name, column_type or infer_column_type(column_name) def split_column_list(columns_text): columns_text = re.sub(r"\s+(and|e)\s+", ",", columns_text or "", flags=re.IGNORECASE) parts = [] type_pattern = r"\b(integer|int|numeric|decimal|real|float|double|text|varchar|char|date|datetime|timestamp|boolean|bool)\b" type_tokens = { "integer", "int", "numeric", "decimal", "real", "float", "double", "text", "varchar", "char", "date", "datetime", "timestamp", "boolean", "bool", } stopwords = {"to", "from", "into", "as", "for", "o", "a", "os", "de", "do", "da", "dos", "das"} for part in (item.strip() for item in columns_text.split(",") if item.strip()): tokens = [token.strip() for token in re.split(r"\s+", part) if token.strip()] tokens = [token for token in tokens if token.lower() not in stopwords] if not tokens: continue if re.search(type_pattern, part, flags=re.IGNORECASE) and len(tokens) > 2: index = 0 inferrable_names = {"total", "date", "time", "timestamp", "int", "text", "real", "char"} while index < len(tokens): current = tokens[index] next_token = tokens[index + 1].lower() if index + 1 < len(tokens) else "" if next_token in type_tokens and not ( current.lower() in inferrable_names and next_token in {"date", "datetime", "timestamp"} ): parts.append(f"{current} {tokens[index + 1]}") index += 2 else: parts.append(current) index += 1 continue if re.search(type_pattern, part, flags=re.IGNORECASE): parts.append(part) continue if len(tokens) > 1 and all(re.match(r"^[A-Za-z_][\wÀ-ÿ]*$", token) for token in tokens): parts.extend(tokens) else: parts.append(part) return parts def format_create_table(table_name, columns): if not table_name or not columns: return "" seen = set() column_lines = [] for column_name, column_type in columns: if column_name in seen: continue seen.add(column_name) column_lines.append(f" {column_name} {column_type}") if not column_lines: return "" return f"CREATE TABLE {table_name} (\n" + ",\n".join(column_lines) + "\n);" def create_table_from_message(message): message = (message or "").strip() patterns = ( r"\b(?:table|tabela)\s+(?:called\s+|named\s+|chamada?\s+|nomeada?\s+)?([A-Za-z_][\w]*)\s+(?:with|containing|including|com|tendo|contendo)\s+(.+)$", rf"\b(?:{_CREATE_VERBS})\b.*?\b(?:table|tabela)\b\s+([A-Za-z_][\w]*)\s+(?:with|containing|including|com|tendo|contendo)\s+(.+)$", ) for pattern in patterns: match = re.search(pattern, message, flags=re.IGNORECASE) if not match: continue table_name = normalize_identifier(match.group(1)) columns = [ parsed for parsed in (parse_column_definition(column) for column in split_column_list(match.group(2))) if parsed ] return format_create_table(table_name, columns) return "" def parse_create_table_schema(schema): schema = (schema or "").strip() match = re.match( r"^\s*(?:CREATE\s+TABLE\s+)?([A-Za-z_][\w]*)\s*\((.*?)\)\s*;?\s*$", schema, flags=re.IGNORECASE | re.DOTALL, ) if not match: return "", [] table_name = normalize_identifier(match.group(1)) columns = [ parsed for parsed in (parse_column_definition(column) for column in split_column_list(match.group(2))) if parsed ] return table_name, columns def _schema_column_names(schema): _table_name, columns = parse_create_table_schema(schema) return [name for name, _column_type in columns] def _canonical_column(name): normalized = normalize_text(name) for canonical, aliases in COLUMN_ALIASES.items(): if normalized == canonical or normalized in aliases: return canonical return normalized _NO_COLUMN_DEFAULT = object() def _column_from_message(normalized_message, schema_columns, canonical, default=_NO_COLUMN_DEFAULT): aliases = COLUMN_ALIASES.get(canonical, {canonical}) if not any(re.search(rf"(? 3: return normalized[:-1] return normalized or "row" def _has_marker(normalized_message, markers): return any(marker in normalized_message for marker in markers) def _clean_english_question(text): text = re.sub(r"\s+", " ", text or "").strip(" .;:") if not text: return LANGUAGE_FALLBACK_QUESTION return text if text.endswith("?") else f"{text}?" def _lexical_translate_sql_question(normalized_message): text = normalized_message for pattern, replacement in PHRASE_TRANSLATIONS: text = re.sub(pattern, replacement, text) words = [] for raw_word in re.findall(r"[A-Za-z_][\w]*|\d+(?:\.\d+)?|[<>=]+", text): normalized_word = normalize_text(raw_word) translated = TOKEN_TRANSLATIONS.get(normalized_word) if not translated: for canonical, aliases in COLUMN_ALIASES.items(): if normalized_word == canonical or normalized_word in aliases: translated = canonical break if not translated: for canonical, aliases in TABLE_ALIASES.items(): if normalized_word == canonical or normalized_word in aliases: translated = canonical break if translated: words.extend(translated.split()) else: words.append(raw_word) return " ".join(words) def normalize_sql_question_to_english(message, schema=""): raw_message = (message or "").strip() normalized_message = normalize_text(raw_message) if not normalized_message: return LANGUAGE_FALLBACK_QUESTION schema_columns = _schema_column_names(schema) subject = _table_subject(schema) if _has_marker(normalized_message, ("mais caro", "mas caro", "plus cher", "maior preco", "highest price", "most expensive")): return _clean_english_question(f"what is the most expensive {subject}") if _has_marker(normalized_message, ("mais barato", "mas barato", "menor preco", "lowest price", "cheapest")): return _clean_english_question(f"what is the cheapest {subject}") if _has_marker(normalized_message, ("media", "moyenne", "promedio", "average")): metric = _first_numeric_focus(normalized_message, schema_columns) dimension = _dimension_after_group_marker(normalized_message, schema_columns) if dimension: return _clean_english_question(f"what is the average {metric} by {dimension}") return _clean_english_question(f"what is the average {metric}") if _has_marker(normalized_message, ("quantos", "cuantos", "combien", "count", "how many")): dimension = _dimension_after_group_marker(normalized_message, schema_columns) if dimension: return _clean_english_question(f"count {subject}s by {dimension}") return _clean_english_question(f"how many {subject}s are there") translated = _lexical_translate_sql_question(normalized_message) has_foreign_marker = any( re.search(rf"(?" if comparison_match.group(2) == "greater than" else "<" return f"SELECT * FROM {table_name} WHERE {actual_column} {operator} {comparison_match.group(3)};" ranking_column = _mentioned_schema_column(normalized_message, schema_columns) if ranking_column and _canonical_column(ranking_column) in NUMERIC_COLUMN_CANONICALS: if ( _has_marker(normalized_message, ("maior", "maximo", "mais alto")) or _has_marker(normalized_english, ("highest", "maximum", "max")) ) and "maior que" not in normalized_message: return f"SELECT * FROM {table_name} ORDER BY {ranking_column} DESC LIMIT 1;" if ( _has_marker(normalized_message, ("menor", "minimo", "mais baixo")) or _has_marker(normalized_english, ("lowest", "minimum", "min")) ) and "menor que" not in normalized_message: return f"SELECT * FROM {table_name} ORDER BY {ranking_column} ASC LIMIT 1;" if normalized_english in {"list all", "show all", "select all rows", "list all rows", "show all rows"}: return f"SELECT * FROM {table_name};" return "" def create_table_from_schema(schema): table_name, columns = parse_create_table_schema(schema) return format_create_table(table_name, columns) def extract_create_table_statement(text): cleaned = extract_sql_candidate(text) match = re.search( r"\bCREATE\s+TABLE\s+[A-Za-z_][\w]*\s*\(.*?\)\s*;?", cleaned, flags=re.IGNORECASE | re.DOTALL, ) return clean_generation(match.group(0)) if match else "" def last_create_table_from_history(chat_history): for item in reversed(list(chat_history or [])): if not isinstance(item, dict) or item.get("role") != "assistant": continue statement = extract_create_table_statement(item.get("content", "")) if statement: return statement return "" def extract_added_columns(message): message = (message or "").strip() patterns = ( r":\s*(.+)$", r"\b(?:add|include|with|adicionar|adicione|adicionando|inclua|incluir|acrescente|ter|coloque|colocar)\b\s+(?:um\s+|uma\s+|a\s+|an\s+)?(?:novo\s+|nova\s+|new\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$", ) for pattern in patterns: match = re.search(pattern, message, flags=re.IGNORECASE) if not match: continue columns = [ parsed for parsed in (parse_column_definition(column) for column in split_column_list(match.group(1))) if parsed ] if columns: return columns return [] def extract_removed_columns(message): message = (message or "").strip() patterns = ( r"\b(?:remove|delete|drop|remova|remover|deletar|exclua|excluir|exclui|tira|tirar|tire|retira|retirar|retire|apaga|apagar|apague)\b\s+(?:a\s+|o\s+|the\s+)?(?:column|field|element|coluna|campo|elemento|item)?\s*(.+)$", ) for pattern in patterns: match = re.search(pattern, message, flags=re.IGNORECASE) if not match: continue columns = [normalize_identifier(column) for column in split_column_list(match.group(1))] columns = [column for column in columns if column] if columns: return columns return [] def extract_renamed_columns(message): pattern = ( r"\b(?:rename|edit|change|renomeie|renomear|renomeia|altere|alterar|altera|mude|mudar|muda|edita|editar)\s+" r"(\w+)\s+(?:to|para|as|como|por)\s+(\w+)" ) matches = re.findall(pattern, message or "", flags=re.IGNORECASE) troca_matches = re.findall(r"\btroca\b\s+(\w+)\s+\bpor\b\s+(\w+)", message or "", flags=re.IGNORECASE) invalid_old_names = {"ela", "ele", "isso", "isto", "essa", "esse", "this", "it"} invalid_new_names = {"ter", "have", "having", "tambem", "tambem"} return [ (normalize_identifier(old), normalize_identifier(new)) for old, new in [*matches, *troca_matches] if normalize_identifier(old) and normalize_identifier(new) and normalize_identifier(old) not in invalid_old_names and normalize_identifier(new) not in invalid_new_names ] def parse_compound_edit(message): segment_pattern = ( r"\s+(?:and|e)\s+" r"(?=\b(?:add|include|remove|delete|drop|rename|edit|change|" r"adicione|adicionar|adicionando|inclua|incluir|acrescente|" r"coloca|coloque|bota|insira|insere|" r"remova|remover|deletar|exclua|excluir|exclui|" r"tira|tirar|tire|retira|retire|apaga|apague|" r"renomeie|renomear|renomeia|altere|alterar|altera|" r"mude|mudar|muda|edita|editar|troca|trocar)\b)" ) segments = re.split(segment_pattern, message or "", flags=re.IGNORECASE) added, removed, renamed = [], [], [] for seg in segments: seg = seg.strip() if not seg: continue if is_rename_intent(seg): renamed.extend(extract_renamed_columns(seg)) elif re.search(r"\b(remove|delete|drop|remova|remover|deletar|exclua|excluir)\b", seg, flags=re.IGNORECASE): removed.extend(extract_removed_columns(seg)) else: cols = extract_added_columns(seg) if cols: added.extend(cols) return added, removed, renamed def edit_create_table_from_message(message, chat_history, active_schema): if not is_table_edit_intent(message) and not is_rename_intent(message): return "" base_sql = last_create_table_from_history(chat_history) or create_table_from_schema(active_schema) table_name, existing_columns = parse_create_table_schema(base_sql) if not table_name: return "" added_columns, removed_columns_list, renamed_columns = parse_compound_edit(message) removed_set = set(extract_removed_columns(message)) | {r for r in removed_columns_list} if not added_columns and not removed_set and not renamed_columns: return "" rename_map = dict(renamed_columns) kept_columns = [ (rename_map.get(col_name, col_name), col_type) for col_name, col_type in existing_columns if col_name not in removed_set ] updated_columns = [*kept_columns, *added_columns] if updated_columns == existing_columns: return "" return format_create_table(table_name, updated_columns) def create_table_from_suggestion(suggestion): if not suggestion: return "" if isinstance(suggestion, dict): table_name = suggestion.get("table_name") columns = [ (column.get("name"), column.get("type", "TEXT")) for column in suggestion.get("columns", []) if isinstance(column, dict) ] else: table_name = getattr(suggestion, "table_name", "") columns = getattr(suggestion, "columns", ()) parsed = [] for name, column_type in columns: identifier = normalize_identifier(name) if identifier: parsed.append((identifier, (column_type or "TEXT").upper())) return format_create_table(normalize_identifier(table_name), parsed)