Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| import streamlit as st | |
| from dotenv import load_dotenv | |
| from langchain.chains import LLMChain | |
| from langchain.prompts import PromptTemplate | |
| load_dotenv() | |
| import os | |
| from typing import List, Tuple | |
| import numpy as np | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.embeddings.openai import OpenAIEmbeddings | |
| from langchain.schema import Document | |
| from data import load_db | |
| from names import DATASET_ID, MODEL_ID | |
| from storage import RedisStorage, UserInput | |
| from utils import weighted_random_sample | |
| class RetrievalType: | |
| FIRST_MATCH = "first-match" | |
| POOL_MATCHES = "pool-matches" | |
| Matches = List[Tuple[Document, float]] | |
| USE_STORAGE = os.environ.get("USE_STORAGE", "True").lower() in ("true", "t", "1") | |
| print("USE_STORAGE", USE_STORAGE) | |
| def init(): | |
| embeddings = OpenAIEmbeddings(model=MODEL_ID) | |
| dataset_path = f"hub://{os.environ['ACTIVELOOP_ORG_ID']}/{DATASET_ID}" | |
| db = load_db( | |
| dataset_path, | |
| embedding_function=embeddings, | |
| token=os.environ["ACTIVELOOP_TOKEN"], | |
| # org_id=os.environ["ACTIVELOOP_ORG_ID"], | |
| read_only=True, | |
| ) | |
| storage = RedisStorage( | |
| host=os.environ["UPSTASH_URL"], password=os.environ["UPSTASH_PASSWORD"] | |
| ) | |
| prompt = PromptTemplate( | |
| input_variables=["user_input"], | |
| template=Path("prompts/bot.prompt").read_text(), | |
| ) | |
| llm = ChatOpenAI(temperature=0.3) | |
| chain = LLMChain(llm=llm, prompt=prompt) | |
| return db, storage, chain | |
| # Don't show the setting sidebar | |
| if "sidebar_state" not in st.session_state: | |
| st.session_state.sidebar_state = "collapsed" | |
| st.set_page_config(initial_sidebar_state=st.session_state.sidebar_state) | |
| db, storage, chain = init() | |
| st.title("FairytaleDJ ๐ต๐ฐ๐ฎ") | |
| st.markdown( | |
| """ | |
| *<small>Made with [DeepLake](https://www.deeplake.ai/) ๐ and [LangChain](https://python.langchain.com/en/latest/index.html) ๐ฆโ๏ธ</small>* | |
| ๐ซ Unleash the magic within you with our enchanting app, turning your sentiments into a Disney soundtrack! ๐ Just express your emotions, and embark on a whimsical journey as we tailor a Disney melody to match your mood. ๐๐""", | |
| unsafe_allow_html=True, | |
| ) | |
| how_it_works = st.expander(label="How it works") | |
| text_input = st.text_input( | |
| label="How are you feeling today?", | |
| placeholder="I am ready to rock and rool!", | |
| ) | |
| run_btn = st.button("Make me sing! ๐ถ") | |
| with how_it_works: | |
| st.markdown( | |
| """ | |
| The application follows a sequence of steps to deliver Disney songs matching the user's emotions: | |
| - **User Input**: The application starts by collecting user's emotional state through a text input. | |
| - **Emotion Encoding**: The user-provided emotions are then fed to a Language Model (LLM). The LLM interprets and encodes these emotions. | |
| - **Similarity Search**: These encoded emotions are utilized to perform a similarity search within our [vector database](https://www.deeplake.ai/). This database houses Disney songs, each represented as emotional embeddings. | |
| - **Song Selection**: From the pool of top matching songs, the application randomly selects one. The selection is weighted, giving preference to songs with higher similarity scores. | |
| - **Song Retrieval**: The selected song's embedded player is displayed on the webpage for the user. Additionally, the LLM interpreted emotional state associated with the chosen song is displayed. | |
| """ | |
| ) | |
| placeholder_emotions = st.empty() | |
| placeholder = st.empty() | |
| with st.sidebar: | |
| st.text("App settings") | |
| filter_threshold = st.slider( | |
| "Threshold used to filter out low scoring songs", | |
| min_value=0.0, | |
| max_value=1.0, | |
| value=0.8, | |
| ) | |
| max_number_of_songs = st.slider( | |
| "Max number of songs we will retrieve from the db", | |
| min_value=5, | |
| max_value=50, | |
| value=20, | |
| step=1, | |
| ) | |
| number_of_displayed_songs = st.slider( | |
| "Number of displayed songs", min_value=1, max_value=4, value=2, step=1 | |
| ) | |
| def filter_scores(matches: Matches, th: float = 0.8) -> Matches: | |
| return [(doc, score) for (doc, score) in matches if score > th] | |
| def normalize_scores_by_sum(matches: Matches) -> Matches: | |
| scores = [score for _, score in matches] | |
| tot = sum(scores) | |
| return [(doc, (score / tot)) for doc, score in matches] | |
| def get_song(user_input: str, k: int = 20): | |
| emotions = chain.run(user_input=user_input) | |
| matches = db.similarity_search_with_score(emotions, distance_metric="cos", k=k) | |
| # [print(doc.metadata['name'], score) for doc, score in matches] | |
| docs, scores = zip( | |
| *normalize_scores_by_sum(filter_scores(matches, filter_threshold)) | |
| ) | |
| choosen_docs = weighted_random_sample( | |
| np.array(docs), np.array(scores), n=number_of_displayed_songs | |
| ).tolist() | |
| return choosen_docs, emotions | |
| def set_song(user_input): | |
| if user_input == "": | |
| return | |
| # take first 120 chars | |
| user_input = user_input[:120] | |
| docs, emotions = get_song(user_input, k=max_number_of_songs) | |
| print(docs) | |
| songs = [] | |
| with placeholder_emotions: | |
| st.markdown("Your emotions: `" + emotions + "`") | |
| with placeholder: | |
| iframes_html = "" | |
| for doc in docs: | |
| name = doc.metadata["name"] | |
| print(f"song = {name}") | |
| songs.append(name) | |
| embed_url = doc.metadata["embed_url"] | |
| iframes_html += ( | |
| f'<iframe src="{embed_url}" style="border:0;height:100px"> </iframe>' | |
| ) | |
| st.markdown( | |
| f"<div style='display:flex;flex-direction:column'>{iframes_html}</div>", | |
| unsafe_allow_html=True, | |
| ) | |
| if USE_STORAGE: | |
| success_storage = storage.store( | |
| UserInput(text=user_input, emotions=emotions, songs=songs) | |
| ) | |
| if not success_storage: | |
| print("[ERROR] was not able to store user_input") | |
| if run_btn: | |
| set_song(text_input) | |