Spaces:
Sleeping
Sleeping
| from semantic import SemanticSearch | |
| import json | |
| import argparse | |
| import os | |
| from embeddings import get_embedding | |
| from openai import OpenAI | |
| import configparser | |
class BackEnd:
    """Retrieval-augmented answering backend for the demo.

    A query is embedded with ``get_embedding``, matching documents are
    retrieved through ``SemanticSearch``, a prompt is assembled from a JSON
    template file, and the prompt is sent to an OpenAI model through the
    Responses API. The model's text output is parsed as JSON.
    """

    # Two-letter language code -> top-level key in the prompt template file.
    _LANGUAGE_KEYS = {'fr': 'French', 'en': 'English'}

    def __init__(self, config, *, model="gpt-4.1",
                 prompt_file='data/demo_prompt.json', lang='fr'):
        """Load the prompt template and create the model and search clients.

        Args:
            config: Configuration object. Not read by this demo version; kept
                so existing ``BackEnd(config)`` callers keep working.
            model: Name of the OpenAI model queried by ``process_query``.
            prompt_file: Path of the JSON file holding the prompt templates.
            lang: Prompt language; only its first two letters are used,
                case-insensitively (e.g. 'fr'/'French', 'en'/'English').

        Raises:
            ValueError: If ``lang`` is not a supported language.
            SystemExit: If the prompt file cannot be read or is not valid
                JSON (non-zero exit status, message printed to stderr).
        """
        language = lang[:2].lower()
        if language not in self._LANGUAGE_KEYS:
            raise ValueError(
                f"Unsupported language {lang!r}; expected one of "
                f"{sorted(self._LANGUAGE_KEYS)}")
        # Load the templates before creating any client so that a bad prompt
        # file fails fast without touching the network-backed objects.
        try:
            with open(prompt_file, encoding='utf-8') as json_file:
                prompts = json.load(json_file)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            # SystemExit with a message prints it to stderr and exits with
            # status 1; the former bare quit() exited with status 0.
            raise SystemExit(
                f"ERROR. Couldn't load prompt file {prompt_file} "
                f"or wrong json format") from err
        self.prompt_template = prompts[self._LANGUAGE_KEYS[language]]
        self.model = model
        self.client = OpenAI()
        self.semantic_search = SemanticSearch()

    @staticmethod
    def _label_documents(documents):
        """Return a new list where each document gets a 'Document N' header (N from 1)."""
        return [f'Document {number}\n\n{document}'
                for number, document in enumerate(documents, start=1)]

    def _build_prompt(self, query, context):
        """Assemble system prompt, query prefix, labelled documents and postfix."""
        template = self.prompt_template
        documents = '\n\n'.join(context)
        # NOTE(review): the prefix receives the labelled *list* as {context},
        # which renders as a Python list repr if the template uses it, and the
        # documents are appended again below. Confirm against the template.
        demo_prefix = template['demo_prefix'].format(query=query, context=context)
        return (template['system_prompt'] + demo_prefix + '\n' + documents
                + '\n\n' + template['demo_postfix'])

    def process_query(self, query):
        """Answer ``query`` using retrieved documents as context.

        Args:
            query: The user's question; it is embedded for retrieval and
                substituted into the ``demo_prefix`` template.

        Returns:
            A ``(answer, context)`` tuple: ``answer`` is the model output
            parsed with ``json.loads``; ``context`` is the list of retrieved
            documents, each prefixed with a ``Document N`` header.

        Raises:
            ValueError: If ``self.model`` is not a 'gpt' model, the only
                backend handled here.
            json.JSONDecodeError: If the model output is not valid JSON.
        """
        # Fail before spending an embedding/search call on an unusable model.
        if 'gpt' not in self.model:
            raise ValueError(
                f"Unsupported model {self.model!r}: only OpenAI 'gpt' models "
                f"are handled")
        query_embeddings = get_embedding(query)
        # Build a new list instead of relabelling the search result in place,
        # so the object returned by SemanticSearch is left untouched.
        context = self._label_documents(
            self.semantic_search.search(query_embeddings))
        print('context = ', context)
        prompt = self._build_prompt(query, context)
        response = self.client.responses.create(model=self.model, input=prompt)
        return json.loads(response.output_text), context
| # def main(): | |
| # parser = argparse.ArgumentParser() | |
| # parser.add_argument('--config_file', type=str, required=True, help='File containing the configuration for the backend (in .ini format)') | |
| # parser.add_argument('--query', type=str, required=False, help='Test query for testing the system') | |
| # args = parser.parse_args() | |
| # config = configparser.ConfigParser() | |
| # config.read(args.config_file) | |
| # backend = BackEnd(config) | |
| # if args.query: | |
| # response = backend.process_query(args.query) | |
| # print(response) | |
| # if __name__ == '__main__': | |
| # main() | |