Article
· 17 hr il y a 30m de lecture

Conversion de texte en IRIS SQL à l'aide de LangChain

Une expérience sur la manière d'utiliser le cadre LangChain, la recherche vectorielle IRIS et les LLM pour générer une base de données SQL compatible IRIS à partir des invites utilisateur.

Cet article a été rédigé à partir du carnet suivant. Vous pouvez l'utiliser dans un environnement prêt à l'emploi avec l'application suivante dans OpenExchange.

Configuration

Tout d'abord, nous devons installer les bibliothèques nécessaires:

!pip install --upgrade --quiet langchain langchain-openai langchain-iris pandas

Ensuite, nous importons les modules requis et configurons l'environnement:

import os
import datetime
import hashlib
from copy import deepcopy
from sqlalchemy import create_engine
import getpass
import pandas as pd
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.docstore.document import Document
from langchain_community.document_loaders import DataFrameLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache
from langchain_iris import IRISVector

Nous utiliserons SQLiteCache pour mettre en cache les appels LLM:

# Cache pour les appels LLM
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

Configurez les paramètres de connexion à la base de données IRIS:

# Paramètres de connexion à la base de données IRIS
os.environ["ISC_LOCAL_SQL_HOSTNAME"] = "localhost"
os.environ["ISC_LOCAL_SQL_PORT"] = "1972"
os.environ["ISC_LOCAL_SQL_NAMESPACE"] = "IRISAPP"
os.environ["ISC_LOCAL_SQL_USER"] = "_system"
os.environ["ISC_LOCAL_SQL_PWD"] = "SYS"

Si la clé API OpenAI n'est pas déjà configurée dans l'environnement, demandez à l'utilisateur de la saisir:

if not "OPENAI_API_KEY" in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass()

Créez la chaîne de connexion pour la base de données IRIS:

# Chaîne de connexion à la base de données IRIS
args = {
    'hostname': os.getenv("ISC_LOCAL_SQL_HOSTNAME"), 
    'port': os.getenv("ISC_LOCAL_SQL_PORT"), 
    'namespace': os.getenv("ISC_LOCAL_SQL_NAMESPACE"), 
    'username': os.getenv("ISC_LOCAL_SQL_USER"), 
    'password': os.getenv("ISC_LOCAL_SQL_PWD")
}
iris_conn_str = f"iris://{args['username']}:{args['password']}@{args['hostname']}:{args['port']}/{args['namespace']}"

Etablissez la connexion avec la base de données IRIS:

# Connexion à la base de données IRIS
engine = create_engine(iris_conn_str)
cnx = engine.connect().connection

Préparez un dictionnaire contenant les renseignements contextuels pour l'invite du système:

# Dict pour les renseignements contextuels de l'invite système
context = {}
context["top_k"] = 3

Création de l'invite

Pour transformer les données de l'utilisateur en requêtes SQL compatibles avec la base de données IRIS, nous devons créer une invite efficace pour le modèle linguistique. Nous commençons par une invite initiale qui fournit des instructions de base pour générer des requêtes SQL. Ce modèle est dérivé des Invites par défaut de LangChain pour MSSQL et personnalisé pour la base de données IRIS.

# Modèle d'invite de base avec instructions SQL pour la base de données IRIS
iris_sql_template = """
Vous êtes un expert InterSystems IRIS.  Compte tenu d'une question d'entrée, créez d'abord une requête InterSystems IRIS syntaxiquement correcte pour exécuter et renvoyer la réponse à la question saisie.
Si l'utilisateur ne spécifie pas dans la question un nombre spécifique d'exemples à obtenir, demandez au maximum {top_k} résultats en utilisant la clause TOP conformément à InterSystems IRIS. Vous pouvez classer les résultats de manière à obtenir les renseignements les plus pertinents de la base de données.
Ne faites jamais de requête pour toutes les colonnes d'une table. Vous ne devez requérir que les colonnes nécessaires pour répondre à la question.. Mettez chaque nom de colonne entre guillemets simples ('') pour indiquer qu'il s'agit d'identifiants délimités.
Veillez à n'utiliser que les noms de colonnes que vous pouvez voir dans les tables ci-dessous. Veillez à ne pas requérir les colonnes qui n'existent pas. Faites également attention à ce que les colonnes se trouvent dans les différents tables.
Veillez à utiliser la fonction CAST(CURRENT_DATE as date) pour obtenir la date du jour, si la question porte sur "aujourd'hui".
Utilisez des guillemets doubles pour délimiter les identifiants des colonnes.
Renvoyez des données SQL simples ; n'appliquez aucune forme de formatage.
"""

Cette invite de base configure le modèle linguistique (LLM) pour qu'il fonctionne comme un expert SQL avec des conseils spécifiques pour la base de données IRIS. Ensuite, nous fournissons une invite auxiliaire avec des renseignements sur le schéma de la base de données pour éviter les hallucinations.

# Extension des modèles SQL pour inclure les renseignements sur le contexte des tables
tables_prompt_template = """
N'utilisez que les tables suivantes:
{table_info}
"""

Afin d'améliorer la précision des réponses du LLM, nous utilisons une technique appelée "incitation en quelques coups" ("few-shot prompting"). Il s'agit de présenter quelques exemples au LLM.

# Extension du modèle SQL pour l'inclusion de quelques exemples
prompt_sql_few_shots_template = """
Vous trouverez ci-dessous un certain nombre d'exemples de questions et de requêtes SQL correspondantes.

{examples_value}
"""

Nous définissons le modèle pour des exemples en quelques coups:

# Modèle d'invite à quelques coups
example_prompt_template = "User input: {input}\nSQL query: {query}"
example_prompt = PromptTemplate.from_template(example_prompt_template)

Nous construisons l'invite utilisateur en utilisant le modèle en quelques coups:

# Modèle d'invite utilisateur
user_prompt = "\n" + example_prompt.invoke({"input": "{input}", "query": ""}).to_string()

Enfin, nous composons toutes les invites pour créer l'invite finale:

# Modèle d'invite complet
prompt = (
    ChatPromptTemplate.from_messages([("system", iris_sql_template)])
    + ChatPromptTemplate.from_messages([("system", tables_prompt_template)])
    + ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)])
    + ChatPromptTemplate.from_messages([("human", user_prompt)])
)
prompt

Cette invite attend les variables examples_value, input, table_info, et top_k.

Voici comment l'invite est structurée:

ChatPromptTemplate(
    input_variables=['examples_value', 'input', 'table_info', 'top_k'], 
    messages=[
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['top_k'], 
                template=iris_sql_template
            )
        ), 
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['table_info'], 
                template=tables_prompt_template
            )
        ), 
        SystemMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['examples_value'], 
                template=prompt_sql_few_shots_template
            )
        ), 
        HumanMessagePromptTemplate(
            prompt=PromptTemplate(
                input_variables=['input'], 
                template=user_prompt
            )
        )
    ]
)

Pour visualiser la manière dont l'invite sera envoyée au LLM, nous pouvons utiliser des valeurs de remplacement pour les variables requises:

prompt_value = prompt.invoke({
    "top_k": "<top_k>",
    "table_info": "<table_info>",
    "examples_value": "<examples_value>",
    "input": "<input>"
})
print(prompt_value.to_string())
Système: 
Vous êtes un expert d'InterSystems IRIS. Compte tenu d'une question d'entrée, créez d'abord une requête InterSystems IRIS syntaxiquement correcte pour exécuter et renvoyer la réponse à la question saisie.
Si l'utilisateur ne spécifie pas dans la question un nombre spécifique d'exemples à obtenir, demandez au maximum <top_k> résultats en utilisant la clause TOP conformément à InterSystems IRIS. Vous pouvez classer les résultats de manière à obtenir les renseignements les plus pertinents de la base de données.
Ne faites jamais de requête pour toutes les colonnes d'une table. Vous ne devez requérir que les colonnes nécessaires pour répondre à la question.. Mettez chaque nom de colonne entre guillemets simples ('') pour indiquer qu'il s'agit d'identifiants délimités.
Veillez à n'utiliser que les noms de colonnes que vous pouvez voir dans les tables ci-dessous. Veillez à ne pas requérir les colonnes qui n'existent pas. Faites également attention à ce que les colonnes se trouvent dans les différents tables.
Veillez à utiliser la fonction CAST(CURRENT_DATE as date) pour obtenir la date du jour, si la question porte sur "aujourd'hui".
Utilisez des guillemets doubles pour délimiter les identifiants des colonnes.
Renvoyez des données SQL simples; n'appliquez aucune forme de formatage.

Système: 
N'utilisez que les tables suivantes:
<table_info>

Système: 
Vous trouverez ci-dessous un certain nombre d'exemples de questions et de requêtes SQL correspondantes.

<examples_value>

Human: 
User input: <input>
SQL query: 

Maintenant, nous sommes prêts à envoyer cette invite au LLM en fournissant les variables nécessaires. Passons à l'étape suivante lorsque vous êtes prêt.

Fourniture des renseignements sur la table

Pour créer des requêtes SQL précises, nous devons fournir au modèle linguistique (LLM) des renseignements détaillés sur les tables de la base de données. Sans ces renseignements, le LLM pourrait générer des requêtes qui semblent plausibles mais qui sont incorrectes en raison d'hallucinations. Par conséquent, notre première étape consiste à créer une fonction qui récupère les définitions des tables de la base de données IRIS.

Fonction de récupération des définitions de tables

La fonction suivante interroge INFORMATION_SCHEMA pour obtenir les définitions de tables pour un schéma donné. Si une table spécifique est fournie, elle récupère la définition de cette table ; sinon, elle récupère les définitions de toutes les tables du schéma.

def get_table_definitions_array(cnx, schema, table=None):
    cursor = cnx.cursor()

    # Requête de base pour obtenir les renseignements sur les colonnes
    query = """
    SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, PRIMARY_KEY, null EXTRA
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_SCHEMA = %s
    """

    # Paramètres de la requête
    params = [schema]

    # Ajout de filtres optionnels
    if table:
        query += " AND TABLE_NAME = %s"
        params.append(table)

    # Exécution de la requête
    cursor.execute(query, params)

    # Récupération des résultats
    rows = cursor.fetchall()

    # Traitement des résultats pour générer la (les) définition(s) de table(s)
    table_definitions = {}
    for row in rows:
        table_schema, table_name, column_name, column_type, is_nullable, column_default, column_key, extra = row
        if table_name not in table_definitions:
            table_definitions[table_name] = []
        table_definitions[table_name].append({
            "column_name": column_name,
            "column_type": column_type,
            "is_nullable": is_nullable,
            "column_default": column_default,
            "column_key": column_key,
            "extra": extra
        })

    primary_keys = {}

    # Construire la chaîne de sortie
    result = []
    for table_name, columns in table_definitions.items():
        table_def = f"CREATE TABLE {schema}.{table_name} (\n"
        column_definitions = []
        for column in columns:
            column_def = f"  {column['column_name']} {column['column_type']}"
            if column['is_nullable'] == "NO":
                column_def += " NOT NULL"
            if column['column_default'] is not None:
                column_def += f" DEFAULT {column['column_default']}"
            if column['extra']:
                column_def += f" {column['extra']}"
            column_definitions.append(column_def)
        if table_name in primary_keys:
            pk_def = f"  PRIMARY KEY ({', '.join(primary_keys[table_name])})"
            column_definitions.append(pk_def)
        table_def += ",\n".join(column_definitions)
        table_def += "\n);"
        result.append(table_def)

    return result

Récupération des définitions de tables pour un schéma

Pour cet exemple, nous utilisons le schéma "Aviation", qui est disponible ici.

# Récupération des définitions de tables pour un schéma "Aviation"
tables = get_table_definitions_array(cnx, "Aviation")
print(tables)

Cette fonction renvoie les instructions CREATE TABLE (creer une table) pour toutes les tables du schéma "Aviation":

[
    'CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  AccidentExplosion varchar,\n  AccidentFire varchar,\n  AirFrameHours varchar,\n  AirFrameHoursSince varchar,\n  AirFrameHoursSinceLastInspection varchar,\n  AircraftCategory varchar,\n  AircraftCertMaxGrossWeight integer,\n  AircraftHomeBuilt varchar,\n  AircraftKey integer NOT NULL,\n  AircraftManufacturer varchar,\n  AircraftModel varchar,\n  AircraftRegistrationClass varchar,\n  AircraftSerialNo varchar,\n  AircraftSeries varchar,\n  Damage varchar,\n  DepartureAirportId varchar,\n  DepartureCity varchar,\n  DepartureCountry varchar,\n  DepartureSameAsEvent varchar,\n  DepartureState varchar,\n  DepartureTime integer,\n  DepartureTimeZone varchar,\n  DestinationAirportId varchar,\n  DestinationCity varchar,\n  DestinationCountry varchar,\n  DestinationSameAsLocal varchar,\n  DestinationState varchar,\n  EngineCount integer,\n  EvacuationOccurred varchar,\n  EventId varchar NOT NULL,\n  FlightMedical varchar,\n  FlightMedicalType varchar,\n  FlightPhase integer,\n  FlightPlan varchar,\n  FlightPlanActivated varchar,\n  FlightSiteSeeing varchar,\n  FlightType varchar,\n  GearType varchar,\n  LastInspectionDate timestamp,\n  LastInspectionType varchar,\n  Missing varchar,\n  OperationDomestic varchar,\n  OperationScheduled varchar,\n  OperationType varchar,\n  OperatorCertificate varchar,\n  OperatorCertificateNum varchar,\n  OperatorCode varchar,\n  OperatorCountry varchar,\n  OperatorIndividual varchar,\n  OperatorName varchar,\n  OperatorState varchar,\n  Owner varchar,\n  OwnerCertified varchar,\n  OwnerCountry varchar,\n  OwnerState varchar,\n  RegistrationNumber varchar,\n  ReportedToICAO varchar,\n  SeatsCabinCrew integer,\n  SeatsFlightCrew integer,\n  SeatsPassengers integer,\n  SeatsTotal integer,\n  SecondPilot varchar,\n  childsub bigint NOT NULL DEFAULT $i(^Aviation.EventC("Aircraft"))\n);',
    'CREATE TABLE Aviation.Crew (\n  Aircraft varchar NOT NULL,\n  ID varchar NOT NULL,\n  Age integer,\n  AircraftKey integer NOT NULL,\n  Category varchar,\n  CrewNumber integer NOT NULL,\n  EventId varchar NOT NULL,\n  Injury varchar,\n  MedicalCertification varchar,\n  MedicalCertificationDate timestamp,\n  MedicalCertificationValid varchar,\n  Seat varchar,\n  SeatbeltUsed varchar,\n  Sex varchar,\n  ShoulderHarnessUsed varchar,\n  ToxicologyTestPerformed varchar,\n  childsub bigint NOT NULL DEFAULT $i(^Aviation.AircraftC("Crew"))\n);',
    'CREATE TABLE Aviation.Event (\n  ID bigint NOT NULL DEFAULT $i(^Aviation.EventD),\n  AirportDirection integer,\n  AirportDistance varchar,\n  AirportElevation integer,\n  AirportLocation varchar,\n  AirportName varchar,\n  Altimeter varchar,\n  EventDate timestamp,\n  EventId varchar NOT NULL,\n  EventTime integer,\n  FAADistrictOffice varchar,\n  InjuriesGroundFatal integer,\n  InjuriesGroundMinor integer,\n  InjuriesGroundSerious integer,\n  InjuriesHighest varchar,\n  InjuriesTotal integer,\n  InjuriesTotalFatal integer,\n  InjuriesTotalMinor integer,\n  InjuriesTotalNone integer,\n  InjuriesTotalSerious integer,\n  InvestigatingAgency varchar,\n  LightConditions varchar,\n  LocationCity varchar,\n  LocationCoordsLatitude double,\n  LocationCoordsLongitude double,\n  LocationCountry varchar,\n  LocationSiteZipCode varchar,\n  LocationState varchar,\n  MidAir varchar,\n  NTSBId varchar,\n  NarrativeCause varchar,\n  NarrativeFull varchar,\n  NarrativeSummary varchar,\n  OnGroundCollision varchar,\n  SkyConditionCeiling varchar,\n  SkyConditionCeilingHeight integer,\n  SkyConditionNonCeiling varchar,\n  SkyConditionNonCeilingHeight integer,\n  TimeZone varchar,\n  Type varchar,\n  Visibility varchar,\n  WeatherAirTemperature integer,\n  WeatherPrecipitation varchar,\n  WindDirection integer,\n  WindDirectionIndicator varchar,\n  WindGust integer,\n  WindGustIndicator varchar,\n  WindVelocity integer,\n  WindVelocityIndicator varchar\n);'
]

Avec ces définitions de tables, nous pouvons passer à l'étape suivante, qui consiste à les intégrer dans notre invite pour le LLM. Cela permet de s'assurer que le LLM a des renseignements précis et complets sur le schéma de la base de données lorsqu'il génère des requêtes SQL.

Sélection des tables les plus pertinentes

Lorsque vous travaillez avec des bases de données, en particulier les plus grandes, l'envoi du langage de définition des données (DDL) pour toutes les tables d'une invite peut s'avérer peu pratique. Si cette approche peut fonctionner pour les petites bases de données, les bases de données réelles contiennent souvent des centaines ou des milliers de tables, ce qui rend inefficace le traitement de chacune d'entre elles.

De plus, il est peu probable qu'un modèle linguistique ait besoin de connaître toutes les tables de la base de données pour générer efficacement des requêtes SQL. Pour relever ce défi, nous pouvons exploiter les capacités de recherche sémantique pour sélectionner uniquement les tables les plus pertinentes en fonction de la requête de l'utilisateu.

Approche

Nous y parvenons en utilisant la recherche sémantique avec IRIS Vector Search. Notez que cette méthode est plus efficace si les identifiants de vos éléments SQL (tels que les tables, les champs et les clés) ont des noms significatifs. Si vos identifiants sont des codes arbitraires, envisagez plutôt d'utiliser un dictionnaire de données.

Étapes

  1. Récupération des renseignements sur les tables

Commencez par extraire les définitions des tables dans d'un objet DataFrame pandas:

# Récupérer les définitions de tables dans un objet DataFrame pandas
table_def = get_table_definitions_array(cnx=cnx, schema='Aviation')
table_df = pd.DataFrame(data=table_def, columns=["col_def"])
table_df["id"] = table_df.index + 1
table_df

L'objet DataFrame (table_df) ressemblera à ceci:

col_def id
0 CREATE TABLE Aviation.Aircraft (\n Event bigi... 1
1 CREATE TABLE Aviation.Crew (\n Aircraft varch... 2
2 CREATE TABLE Aviation.Event (\n ID bigint NOT... 3
  1. Répartition des définitions dans des documents

Ensuite, répartissez les définitions des tables dans les documents Langchain. . Cette étape est cruciale pour gérer de gros fragment de texte et extraire des incorporations de texte:

loader = DataFrameLoader(table_df, page_content_column="col_def")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
tables_docs = text_splitter.split_documents(documents)
tables_docs

La liste tables_docs qui en résulte contient des documents fractionnés avec des métadonnées, comme suit:

[Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  ...'),
 Document(metadata={'id': 2}, page_content='CREATE TABLE Aviation.Crew (\n  Aircraft varchar NOT NULL,\n  ID varchar NOT NULL,\n  ...'),
 Document(metadata={'id': 3}, page_content='CREATE TABLE Aviation.Event (\n  ID bigint NOT NULL DEFAULT $i(^Aviation.EventD),\n  ...')]
  1. Extraction des incorporations et stockage dans IRIS

Utilisez maintenant la classe IRISVector de langchain-iris pour extraire les vecteurs d'intégration et les stocker:

tables_vector_store = IRISVector.from_documents(
    embedding=OpenAIEmbeddings(), 
    documents=tables_docs,
    connection_string=iris_conn_str,
    collection_name="sql_tables",
    pre_delete_collection=True
)

Remarque : l'indicateur pre_delete_collection est fixé à True (vrai) à des fins de démonstration, afin de garantir une nouvelle collection à chaque exécution du test. Dans un environnement de production, cet indicateur doit généralement être défini sur False (faux).

  1. Recherche de documents pertinents
    Avec les incorporations de table stockées, vous pouvez désormais interroger les tables pertinentes en fonction des données de l'utilisateur:
input_query = "List the first 2 manufacturers"
relevant_tables_docs = tables_vector_store.similarity_search(input_query, k=3)
relevant_tables_docs

Par exemple, une requête portant sur les fabricants peut aboutir à un résultat:

[Document(metadata={'id': 1}, page_content='GearType varchar,\n  LastInspectionDate timestamp,\n  ...'),
 Document(metadata={'id': 1}, page_content='AircraftModel varchar,\n  AircraftRegistrationClass varchar,\n  ...'),
 Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n  LocationState varchar,\n  ...')]

À partir des métadonnées, vous pouvez voir que seul la table ID 1 (Aviation.Avion) est pertinente, ce qui correspond à la requête.

  1. Gestion des cas limites

Bien que cette approche soit généralement efficace, elle n'est pas toujours parfaite. Par exemple, la recherche de sites d'accidents peut également renvoyer des tables moins pertinentes:

input_query = "List the top 10 most crash sites"
relevant_tables_docs = tables_vector_store.similarity_search(input_query, k=3)
relevant_tables_docs

Les résultats peuvent inclure ce qui suit:

[Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n  LocationState varchar,\n  ...'),
 Document(metadata={'id': 3}, page_content='InjuriesGroundSerious integer,\n  InjuriesHighest varchar,\n  ...'),
 Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  ...')]

Bien que la table Aviation.Event ait été récupérée deux fois, la table Aviation.Aircraft peut également apparaître, ce qui pourrait être amélioré par un filtrage ou un seuillage supplémentaire. Cela dépasse le cadre de cet exemple et sera laissé à l'appréciation de futures implémentations.

  1. Définition d'une fonction pour récupérer les tables pertinentes

Pour automatiser ce processus, définissez une fonction qui filtre et renvoie les tableaux pertinents en fonction des données de l'utilisateur:

def get_relevant_tables(user_input, tables_vector_store, table_df):
    relevant_tables_docs = tables_vector_store.similarity_search(user_input)
    relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs]
    indices = table_df["id"].isin(relevant_tables_docs_indices)
    relevant_tables_array = [x for x in table_df[indices]["col_def"]]
    return relevant_tables_array

Cette fonction permet d'extraire efficacement les tables pertinentes à envoyer au LLM, de réduire la longueur de l'invite et d'améliorer les performances globales de la requête.

Sélection des exemples les plus pertinents ("invitation en quelques coups")

Lorsque vous travaillez avec des modèles linguistiques (LLM), le fait de leur fournir des exemples pertinents permet d'obtenir des réponses précises et adaptées au contexte. Ces exemples, appelés "quelques exemples", guident le LLM dans la compréhension de la structure et du contexte des requêtes qu'il doit traiter.

Dans notre cas, nous devons remplir la variable examples_value avec un ensemble varié de requêtes SQL qui couvrent un large spectre de la syntaxe SQL d'IRIS et des tables disponibles dans la base de données. Cela permet d'éviter que le LLM ne génère des requêtes incorrectes ou non pertinentes.

Définition de requêtes d'exemple

Vous trouverez ci-dessous une liste d'exemples de requêtes conçues pour illustrer diverses opérations SQL:

examples = [
    {"input": "List all aircrafts.", "query": "SELECT * FROM Aviation.Aircraft"},
    {"input": "Find all incidents for the aircraft with ID 'N12345'.", "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"},
    {"input": "List all incidents in the 'Commercial' operation type.", "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE OperationType = 'Commercial')"},
    {"input": "Find the total number of incidents.", "query": "SELECT COUNT(*) FROM Aviation.Event"},
    {"input": "List all incidents that occurred in 'Canada'.", "query": "SELECT * FROM Aviation.Event WHERE LocationCountry = 'Canada'"},
    {"input": "How many incidents are associated with the aircraft with AircraftKey 5?", "query": "SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5"},
    {"input": "Find the total number of distinct aircrafts involved in incidents.", "query": "SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft"},
    {"input": "List all incidents that occurred after 5 PM.", "query": "SELECT * FROM Aviation.Event WHERE EventTime > 1700"},
    {"input": "Who are the top 5 operators by the number of incidents?", "query": "SELECT TOP 5 OperatorName, COUNT(*) AS IncidentCount FROM Aviation.Aircraft GROUP BY OperatorName ORDER BY IncidentCount DESC"},
    {"input": "Which incidents occurred in the year 2020?", "query": "SELECT * FROM Aviation.Event WHERE YEAR(EventDate) = '2020'"},
    {"input": "What was the month with most events in the year 2020?", "query": "SELECT TOP 1 MONTH(EventDate) EventMonth, COUNT(*) EventCount FROM Aviation.Event WHERE YEAR(EventDate) = '2020' GROUP BY MONTH(EventDate) ORDER BY EventCount DESC"},
    {"input": "How many crew members were involved in incidents?", "query": "SELECT COUNT(*) FROM Aviation.Crew"},
    {"input": "List all incidents with detailed aircraft information for incidents that occurred in the year 2012.", "query": "SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012"},
    {"input": "Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.", "query": "SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5"},
    {"input": "List all crew members involved in incidents with serious injuries, along with the incident date and location.", "query": "SELECT c.CrewNumber AS 'Crew Number', c.Age, c.Sex AS Gender, e.EventDate AS 'Event Date', e.LocationCity AS 'Location City', e.LocationState AS 'Location State' FROM Aviation.Crew c JOIN Aviation.Event e ON c.EventId = e.EventId WHERE c.Injury = 'Serious'"}
]

Sélection d'exemples pertinents

Compte tenu de la liste d'exemples qui ne cesse de s'allonger, il n'est pas pratique de fournir tous les exemples au LLM. Au lieu de cela, nous utilisons la recherche vectorielle d'IRIS et la classe SemanticSimilarityExampleSelector pour identifier les exemples les plus pertinents sur la base des invites de l'utilisateur.

Définition du Sélecteur d'exemples:

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    IRISVector,
    k=5,
    input_keys=["input"],
    connection_string=iris_conn_str,
    collection_name="sql_samples",
    pre_delete_collection=True
)

Remarque : l'indicateur pre_delete_collection est fixé ici à des fins de démonstration, afin de garantir une nouvelle collection à chaque exécution du test. Dans un environnement de production, cet indicateur doit être défini sur Faux (false) pour éviter les suppressions inutiles.

Requéte du sélecteur:

Pour trouver les exemples les plus pertinents pour une saisie donnée, utilisez le sélecteur comme suit:

input_query = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
relevant_examples = example_selector.select_examples({"input": input_query})

Les résultats pourraient ressembler à ceci:

[{'input': 'List all incidents with detailed aircraft information for incidents that occurred in the year 2012.', 'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012'},
 {'input': "Find all incidents for the aircraft with ID 'N12345'.", 'query': "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"},
 {'input': 'Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.', 'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
 {'input': 'List all aircrafts.', 'query': 'SELECT * FROM Aviation.Aircraft'},
 {'input': 'Find the total number of distinct aircrafts involved in incidents.', 'query': 'SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft'}]

Si vous avez spécifiquement besoin d'exemples liés aux quantités, vous pouvez interroger le sélecteur en conséquence:

input_query = "What is the number of incidents involving Boeing aircraft."
quantity_examples = example_selector.select_examples({"input": input_query})

Le résultat peut être comme suit:

[{'input': 'How many incidents are associated with the aircraft with AircraftKey 5?', 'query': 'SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5'},
 {'input': 'Find the total number of distinct aircrafts involved in incidents.', 'query': 'SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft'},
 {'input': 'How many crew members were involved in incidents?', 'query': 'SELECT COUNT(*) FROM Aviation.Crew'},
 {'input': 'Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.', 'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
 {'input': 'List all incidents with detailed aircraft information for incidents that occurred in the year 2012.', 'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012'}]

Ce résultat comprend des exemples qui traitent spécifiquement du comptage et des quantités.

Considérations futures

Bien que le sélecteur SemanticSimilarityExampleSelector soit puissant, il est important de noter que tous les exemples sélectionnés ne sont pas forcément parfaits. Les améliorations futures peuvent impliquer l'ajout de filtres ou de seuils pour exclure les résultats moins pertinents, garantissant que seuls les exemples les plus appropriés sont fournis au LLM.

Test de précision

Pour évaluer les performances de l'invite et de la génération de requêtes SQL, nous devons mettre en place et exécuter une série de tests. L'objectif est d'évaluer dans quelle mesure le LLM génère des requêtes SQL basées sur les données de l'utilisateur, avec et sans l'utilisation de quelques coups basés sur des exemples.

Fonction de génération de requêtes SQL

Nous commençons par définir une fonction qui utilise le LLM pour générer des requêtes SQL en fonction du contexte fourni, de l'invite, de la saisie de l'utilisateur et d'autres paramètres:

def get_sql_from_text(context, prompt, user_input, use_few_shots, tables_vector_store, table_df, example_selector=None, example_prompt=None):
    relevant_tables = get_relevant_tables(user_input, tables_vector_store, table_df)
    context["table_info"] = "\n\n".join(relevant_tables)

    examples = example_selector.select_examples({"input": user_input}) if example_selector else []
    context["examples_value"] = "\n\n".join([
        example_prompt.invoke(x).to_string() for x in examples
    ])

    model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    output_parser = StrOutputParser()
    chain_model = prompt | model | output_parser

    response = chain_model.invoke({
        "top_k": context["top_k"],
        "table_info": context["table_info"],
        "examples_value": context["examples_value"],
        "input": user_input
    })
    return response

Nous commençons par définir une fonction qui utilise le LLM pour générer des requêtes SQL en fonction du contexte fourni, de l'invite, de la saisie de l'utilisateur et d'autres paramètres

Testez l'invite avec et sans exemples:

# Exécution de l'invite **avec** quelques coups
input = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
response_with_few_shots = get_sql_from_text(
    context, 
    prompt, 
    user_input=input, 
    use_few_shots=True, 
    tables_vector_store=tables_vector_store, 
    table_df=table_df,
    example_selector=example_selector, 
    example_prompt=example_prompt,
)
print(response_with_few_shots)
SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.EventId = a.EventId
WHERE Year(e.EventDate) = 2010
# Exécution de l'invite **sans** quelques coups
input = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
response_with_no_few_shots = get_sql_from_text(
    context, 
    prompt, 
    user_input=input, 
    use_few_shots=False, 
    tables_vector_store=tables_vector_store, 
    table_df=table_df,
)
print(response_with_no_few_shots)
SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel"
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.ID = a.Event
WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'
Utility Functions for Testing

Pour tester les requêtes SQL générées, nous définissons quelques fonctions utilitaires:

def execute_sql_query(cnx, query):
    try:
        cursor = cnx.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        return rows
    except:
        print('Error running query:')
        print(query)
        print('-'*80)
    return None

def sql_result_equals(cnx, query, expected):
    rows = execute_sql_query(cnx, query)
    result = [set(row._asdict().values()) for row in rows or []]
    if result != expected and rows is not None:
        print('Result not as expected for query:')
        print(query)
        print('-'*80)
    return result == expected
# Test SQL pour l'invite **avec** quelques coups
print("SQL is OK" if not execute_sql_query(cnx, response_with_few_shots) is None else "SQL is not OK")
    SQL is OK
# Test SQL pour l'invite **sans** quelques coups
print("SQL is OK" if not execute_sql_query(cnx, response_with_no_few_shots) is None else "SQL is not OK")
    error on running query: 
    SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel"
    FROM Aviation.Event e
    JOIN Aviation.Aircraft a ON e.ID = a.Event
    WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'
    --------------------------------------------------------------------------------
    SQL is not OK

Définition et exécution des tests

Définissez un ensemble de scénarios de test et les exécutez:

tests = [{
    "input": "What were the top 3 years with the most recorded events?",
    "expected": [{128, 2003}, {122, 2007}, {117, 2005}]
},{
    "input": "How many incidents involving Boeing aircraft.",
    "expected": [{5}]
},{
    "input": "How many incidents that resulted in fatalities.",
    "expected": [{237}]
},{
    "input": "List event Id and date and, crew number, age and gender for incidents that occurred in 2013.",
    "expected": [{1, datetime.datetime(2013, 3, 4, 11, 6), '20130305X71252', 59, 'M'},
                 {1, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 32, 'M'},
                 {2, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 35, 'M'},
                 {1, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 25, 'M'},
                 {2, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 34, 'M'},
                 {1, datetime.datetime(2013, 2, 1, 15, 0), '20130203X53401', 29, 'M'},
                 {1, datetime.datetime(2013, 2, 15, 15, 0), '20130218X70747', 27, 'M'},
                 {1, datetime.datetime(2013, 3, 2, 15, 0), '20130303X21011', 49, 'M'},
                 {1, datetime.datetime(2013, 3, 23, 13, 52), '20130326X85150', 'M', None}]
},{
    "input": "Find the total number of incidents that occurred in the United States.",
    "expected": [{1178}]
},{
    "input": "List all incidents latitude and longitude coordinates with more than 5 injuries that occurred in 2010.",
    "expected": [{-78.76833333333333, 43.25277777777778}]
},{
    "input": "Find all incidents in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model.",
    "expected": [
        {datetime.datetime(2010, 5, 20, 13, 43), '20100520X60222', 'CIRRUS DESIGN CORP', 'Farmingdale', 'New York', 'SR22'},
        {datetime.datetime(2010, 4, 11, 15, 0), '20100411X73253', 'CZECH AIRCRAFT WORKS SPOL SRO', 'Millbrook', 'New York', 'SPORTCRUISER'},
        {'108', datetime.datetime(2010, 1, 9, 12, 55), '20100111X41106', 'Bayport', 'New York', 'STINSON'},
        {datetime.datetime(2010, 8, 1, 14, 20), '20100801X85218', 'A185F', 'CESSNA', 'New York', 'Newfane'}
    ]
}]

Évaluation de la précision

Exécutez les tests et calculez la précision:

def execute_tests(cnx, context, prompt, use_few_shots, tables_vector_store, table_df, example_selector, example_prompt):
    tests_generated_sql = [(x, get_sql_from_text(
            context, 
            prompt, 
            user_input=x['input'], 
            use_few_shots=use_few_shots, 
            tables_vector_store=tables_vector_store, 
            table_df=table_df,
            example_selector=example_selector if use_few_shots else None, 
            example_prompt=example_prompt if use_few_shots else None,
        )) for x in deepcopy(tests)]

    tests_sql_executions = [(x[0], sql_result_equals(cnx, x[1], x[0]['expected'])) 
                            for x in tests_generated_sql]

    accuracy = sum(1 for i in tests_sql_executions if i[1] == True) / len(tests_sql_executions)
    print(f'Accuracy: {accuracy}')
    print('-'*80)

Résultats

# Tests de précision pour les invites exécutées **sans** quelques coups
use_few_shots = False
execute_tests(
    cnx,
    context, 
    prompt, 
    use_few_shots, 
    tables_vector_store, 
    table_df, 
    example_selector, 
    example_prompt
)
    error on running query: 
    SELECT "EventDate", COUNT("EventId") as "TotalEvents"
    FROM Aviation.Event
    GROUP BY "EventDate"
    ORDER BY "TotalEvents" DESC
    TOP 3;
    --------------------------------------------------------------------------------
    error on running query: 
    SELECT "EventId", "EventDate", "C"."CrewNumber", "C"."Age", "C"."Sex"
    FROM "Aviation.Event" AS "E"
    JOIN "Aviation.Crew" AS "C" ON "E"."ID" = "C"."EventId"
    WHERE "E"."EventDate" >= '2013-01-01' AND "E"."EventDate" < '2014-01-01'
    --------------------------------------------------------------------------------
    result not expected for query: 
    SELECT TOP 3 "e"."EventId", "e"."EventDate", "e"."LocationCity", "e"."LocationState", "a"."AircraftManufacturer", "a"."AircraftModel"
    FROM "Aviation"."Event" AS "e"
    JOIN "Aviation"."Aircraft" AS "a" ON "e"."ID" = "a"."Event"
    WHERE "e"."EventDate" >= '2010-01-01' AND "e"."EventDate" < '2011-01-01'
    --------------------------------------------------------------------------------
    accuracy: 0.5714285714285714
    --------------------------------------------------------------------------------
# Tests de précision pour les invites exécutées **avec** quelques coups
use_few_shots = True
execute_tests(
    cnx,
    context, 
    prompt, 
    use_few_shots, 
    tables_vector_store, 
    table_df, 
    example_selector, 
    example_prompt
)
    error on running query: 
    SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel
    FROM Aviation.Event e
    JOIN Aviation.Aircraft a ON e.EventId = a.EventId
    WHERE Year(e.EventDate) = 2010 TOP 3
    --------------------------------------------------------------------------------
    accuracy: 0.8571428571428571
    --------------------------------------------------------------------------------

Conclusion

La précision des requêtes SQL générées avec des exemples (quelques coups) est environ 49% plus élevée que celles générées sans exemples (85% contre 57%).

Références

Discussion (0)2
Connectez-vous ou inscrivez-vous pour continuer