import uuid
from io import BytesIO

import fitz  # PyMuPDF: PDF highlighting and page rendering
import torch
import streamlit as st
from streamlit.logger import get_logger
from auto_gptq import AutoGPTQForCausalLM
from gtts import gTTS
from langchain import HuggingFacePipeline, PromptTemplate
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from PIL import Image
from transformers import AutoTokenizer, TextStreamer, pipeline
|
|
# One id per browser session (currently unused elsewhere in the app).
user_session_id = uuid.uuid4()
|
|
| logger = get_logger(__name__) |
st.set_page_config(page_title="Document QA by Dono", page_icon="🤖")
# Initialise the flag only once, so a rerun doesn't re-enable widgets that a callback disabled.
if "disabled" not in st.session_state:
    st.session_state.disabled = False
| st.title("Document QA by Dono") |
| DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" |
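# Note: AutoGPTQ inference is only practical on a CUDA GPU; the CPU fallback keeps
# the app importable, but 13B GPTQ generation on CPU is unusably slow.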
|
|
|
|
@st.cache_data
def load_data():
    # Parse every PDF in the corpus directory into LangChain documents.
    loader = PyPDFDirectoryLoader("/home/user/app/pdfs/")
    docs = loader.load()
    return docs
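
# st.cache_data keeps the parsed documents across reruns; clear the cache (or restart
# the app) after changing the contents of the pdfs/ directory.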
|
|
@st.cache_resource
def load_model(_docs):
    # The leading underscore tells st.cache_resource not to hash the documents argument.
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="/home/user/app/all-MiniLM-L6-v2/",
        model_kwargs={"device": DEVICE},
    )
    # Chunk the documents and build an in-memory FAISS index over their embeddings.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    texts = text_splitter.split_documents(_docs)
    db = FAISS.from_documents(texts, embeddings)

    model_name_or_path = "/home/user/app/Llama-2-13B-chat-GPTQ/"
    model_basename = "model"
|
|
| tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True) |
|
|
    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        revision="gptq-8bit-128g-actorder_False",
        model_basename=model_basename,
        use_safetensors=True,
        trust_remote_code=True,
        inject_fused_attention=False,  # fused attention is incompatible with some Llama-2 GPTQ weights
        device=DEVICE,
        quantize_config=None,
    )
|
|
    DEFAULT_SYSTEM_PROMPT = """
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
Always cite the passage that supports your answer, including its page number and any section or subsection it appears in.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
Given a government document that outlines rules and regulations for a specific industry or sector, answer questions about the rules and their applicability over time.
The document may include provisions that take effect at different times, such as immediately upon publication, after a grace period, or on a specific date in the future.
Your task is to identify the relevant rules and determine when they go into effect, taking into account any dependencies or exceptions that may apply.
The current date is 14 September, 2023. Prefer information that is closest to this date.
Take a deep breath and work on this problem step-by-step.
""".strip()
|
|
|
|
    def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
        return f"""[INST] <<SYS>>
{system_prompt}
<</SYS>>

{prompt} [/INST]""".strip()
|
|
    # Stream tokens to stdout as they are generated (visible in the server logs).
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    text_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0.2,
        top_p=0.95,
        repetition_penalty=1.15,
        streamer=streamer,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0.2})
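    # When given a ready-built pipeline, HuggingFacePipeline does not reload the model,
    # so the temperature in model_kwargs is effectively ignored; the pipeline's
    # temperature=0.2 is the value that applies.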
|
|
| SYSTEM_PROMPT = ("Use the following pieces of context to answer the question at the end. " |
| "If you don't know the answer, just say that you don't know, " |
| "don't try to make up an answer.") |
|
|
    template = generate_prompt("""{context} Question: {question}""", system_prompt=SYSTEM_PROMPT)
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
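    # "stuff" simply packs the k=5 retrieved chunks into the prompt's {context} slot;
    # five ~1024-character chunks fit comfortably within Llama-2's 4k-token window.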
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 5}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "verbose": False},
    )
|
|
    logger.info("load done")
| return qa_chain |
|
|
|
|
# Display name for the model backing this session; the quantized weights themselves
# are loaded inside load_model().
model_name_or_path = "Llama-2-13B-chat-GPTQ"
st.session_state["llm_model"] = model_name_or_path
|
|
| if "messages" not in st.session_state: |
| st.session_state.messages = [] |
|
|
|
|
| for message in st.session_state.messages: |
| with st.chat_message(message["role"]): |
| st.markdown(message["content"]) |
|
|
|
|
def on_select():
    # Callback to freeze input widgets; kept for future use (nothing wires to it yet).
    st.session_state.disabled = True
|
|
|
|
def get_message_history():
    # Yield the chat transcript as "Role: content" lines.
    for message in st.session_state.messages:
        role, content = message["role"], message["content"]
        yield f"{role.title()}: {content}"
|
|
|
|
| docs = load_data() |
| qa_chain = load_model(docs) |
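
# First run is slow: it parses the PDFs, builds the FAISS index and loads the GPTQ
# weights; the cache decorators make every subsequent rerun effectively free.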
|
|
if prompt := st.chat_input("How can I help you today?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        # Last three turns of history; assembled but not yet passed to the chain
        # (RetrievalQA is stateless; a ConversationalRetrievalChain could consume it).
        message_history = "\n".join(list(get_message_history())[-3:])
        result = qa_chain(prompt)
        output = [result['result']]
|
|
        def generate_pdf():
            # Speak the answer and show the highlighted source page together.
            generate_audio()
            source_doc = result['source_documents'][0]
            page_number = int(source_doc.metadata['page'])
            doc = fitz.open(str(source_doc.metadata['source']))
            text = str(source_doc.page_content)
            if text != '':
                for page in doc:
                    # Exact-string search: chunks broken across lines or pages may
                    # not be found, in which case nothing is highlighted.
                    for inst in page.search_for(text):
                        highlight = page.add_highlight_annot(inst)
                        highlight.update()
            doc.save("/home/user/app/pdf2image/output.pdf", garbage=4, deflate=True, clean=True)
|
|
            def pdf_page_to_image(pdf_file, page_number, output_image):
                pdf_document = fitz.open(pdf_file)
                page = pdf_document[page_number]
                # PyMuPDF renders at 72 dpi by default, so scale by dpi/72 for a true 300 dpi image.
                dpi = 300
                pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72))
                pix.save(output_image, "png")
                pdf_document.close()

            pdf_page_to_image('/home/user/app/pdf2image/output.pdf', page_number, '/home/user/app/pdf2image/output.png')
            image = Image.open('/home/user/app/pdf2image/output.png')
            st.image(image)
|
|
        def generate_audio():
            # gTTS synthesises speech via Google's online TTS service, so this needs network access.
            sound_file = BytesIO()
            tts = gTTS(result['result'], lang='en')
            tts.write_to_fp(sound_file)
            st.audio(sound_file)
            st.session_state.sound_played = True
|
|
|
|
        # Typewriter-style reveal with a cursor; output holds a single string,
        # so this completes in one step.
        for item in output:
            full_response += item
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
|
|
| st.button('Reference',on_click=generate_pdf) |
| st.button(':speaker:',on_click=generate_audio) |
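        # Caveat: on_click callbacks run at the start of the next rerun, so the page image
        # and audio they produce render above the chat transcript, not inside this bubble.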
|
|
|
|
|
|
|
|
| st.session_state.messages.append({"role": "assistant", "content": full_response}) |
|
|
|
|
|
|