# RobustDialogueDemo / backend.py
# DialogueRobust
# First commit
# e490ab5
from semantic import SemanticSearch
import json
import argparse
import os
from embeddings import get_embedding
from openai import OpenAI
import configparser
class BackEnd:
    """Retrieval-augmented question-answering backend.

    Retrieves documents relevant to a query with ``SemanticSearch``, builds a
    prompt from a JSON prompt template, sends it to an OpenAI model through the
    Responses API and parses the model's reply as JSON.
    """

    # Prompt template file. NOTE: intended to come from
    # config['ChatBot']['prompt file']; hard-coded for the demo.
    PROMPT_FILE = 'data/demo_prompt.json'

    # Two-letter language code -> top-level section of the prompt file.
    _LANGUAGE_KEYS = {'fr': 'French', 'en': 'English'}

    def __init__(self, config):
        """Create the OpenAI client and the semantic search, and load prompts.

        Args:
            config: parsed configuration (e.g. a ``configparser.ConfigParser``).
                Currently unused: model, prompt file and language are
                hard-coded for the demo; the intended config keys are noted
                next to each value.

        Raises:
            SystemExit: with status 1 if the prompt file cannot be read or is
                not valid JSON.
            ValueError: if the language code is not in ``_LANGUAGE_KEYS``.
            KeyError: if the prompt file lacks the section for the language.
        """
        self.model = "gpt-4.1"  # config['ChatBot']['model']
        self.client = OpenAI()
        self.semantic_search = SemanticSearch()  # config['Semantic Search']
        try:
            # Explicit UTF-8: the platform default encoding is not always
            # UTF-8, and the French prompts may contain accented characters.
            with open(self.PROMPT_FILE, encoding='utf-8') as json_file:
                prompts = json.load(json_file)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as err:
            # Report the file that was actually opened; reading the path from
            # config here could itself raise KeyError inside the handler.
            print(f"ERROR. Couldn't load prompt file {self.PROMPT_FILE} "
                  f"or wrong json format ({err})")
            # Non-zero exit status; quit() exits with 0 and only exists when
            # the `site` module is loaded.
            raise SystemExit(1) from err
        lang = 'fr'  # config['General']['language'][:2].lower()
        self.prompt_template = self._select_template(prompts, lang)

    @classmethod
    def _select_template(cls, prompts, lang):
        """Return the section of *prompts* for the two-letter code *lang*.

        Raises:
            ValueError: if *lang* is not a supported language code (the old
                if/elif left ``prompt_template`` unset in that case).
            KeyError: if *prompts* has no section for the language.
        """
        try:
            section = cls._LANGUAGE_KEYS[lang]
        except KeyError:
            raise ValueError(
                f"Unsupported language {lang!r}; "
                f"expected one of {sorted(cls._LANGUAGE_KEYS)}"
            ) from None
        return prompts[section]

    @staticmethod
    def _label_documents(documents):
        """Return a new list with each document prefixed by 'Document <n>'.

        Numbering starts at 1. A new list is built instead of relabelling in
        place so the sequence returned by the search backend is not modified.
        """
        return [f'Document {number}\n\n{document}'
                for number, document in enumerate(documents, start=1)]

    def _build_prompt(self, query, context):
        """Assemble the full prompt for *query* from the labelled *context*.

        Layout: system_prompt, then demo_prefix formatted with ``query`` and
        ``context``, a newline, the documents separated by blank lines, a
        blank line, and finally demo_postfix.
        """
        template = self.prompt_template
        demo_prefix = template['demo_prefix'].format(query=query, context=context)
        documents = '\n\n'.join(context)
        return (template['system_prompt'] + demo_prefix + '\n'
                + documents + '\n\n' + template['demo_postfix'])

    def _call_model(self, prompt):
        """Send *prompt* to the configured model and parse its reply as JSON.

        Raises:
            ValueError: if ``self.model`` is not an OpenAI "gpt" model (the
                original code left ``response`` unbound in that case).
            json.JSONDecodeError: if the model's output is not valid JSON.
        """
        if 'gpt' not in self.model:
            raise ValueError(
                f"Unsupported model {self.model!r}: only 'gpt' models are handled")
        response = self.client.responses.create(model=self.model, input=prompt)
        return json.loads(response.output_text)

    def process_query(self, query):
        """Answer *query* using documents retrieved by semantic search.

        Returns:
            A ``(answer, context)`` tuple: ``answer`` is the model output parsed
            with ``json.loads`` and ``context`` is the list of labelled
            documents that was inserted into the prompt.
        """
        query_embeddings = get_embedding(query)
        context = self._label_documents(self.semantic_search.search(query_embeddings))
        print('context = ', context)
        prompt = self._build_prompt(query, context)
        return self._call_model(prompt), context
# def main():
# parser = argparse.ArgumentParser()
# parser.add_argument('--config_file', type=str, required=True, help='File containing the configuration for the backend (in .ini format)')
# parser.add_argument('--query', type=str, required=False, help='Test query for testing the system')
# args = parser.parse_args()
# config = configparser.ConfigParser()
# config.read(args.config_file)
# backend = BackEnd(config)
# if args.query:
# response = backend.process_query(args.query)
# print(response)
# if __name__ == '__main__':
# main()