Upload 2 files
app.py
CHANGED
@@ -10,7 +10,8 @@ from utils import (
     clean_entities,
     create_dense_embeddings,
     create_sparse_embeddings,
-
+    extract_quarter_year,
+    extract_ticker_spacy,
     format_query,
     get_flan_alpaca_xl_model,
     generate_alpaca_ner_prompt,
@@ -70,9 +71,12 @@ with col1:
     if ner_choice == "Alpaca":
         ner_prompt = generate_alpaca_ner_prompt(query_text)
         entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
-        company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(
+        company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(
+            entity_text
+        )
     else:
-        company_ent
+        company_ent = extract_ticker_spacy(query_text, ner_model)
+        quarter_ent, year_ent = extract_quarter_year(query_text)

     ticker_index, quarter_index, year_index = clean_entities(
         company_ent, quarter_ent, year_ent
utils.py
CHANGED
@@ -5,6 +5,7 @@ import requests
 import openai
 import pandas as pd
 import spacy
+import spacy_transformers
 import streamlit_scrollable_textbox as stx
 import torch
 from sentence_transformers import SentenceTransformer
@@ -33,13 +34,17 @@ def get_data():

 @st.experimental_singleton
 def get_spacy_model():
-    return spacy.load("
+    return spacy.load("en_core_web_trf")


 @st.experimental_singleton
 def get_flan_alpaca_xl_model():
-    model = AutoModelForSeq2SeqLM.from_pretrained(
-
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        "/home/user/app/models/flan-alpaca-xl/"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        "/home/user/app/models/flan-alpaca-xl/"
+    )
     return model, tokenizer


@@ -478,6 +483,7 @@ Answer:?"""

 # Entity Extraction

+
 def generate_alpaca_ner_prompt(query):
     prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Use the following guidelines to extract the entities representing the Company, Quarter, and Year in the sentence.

@@ -515,19 +521,27 @@ Company - Cisco, Quarter - none, Year - 2016
 ### Response:"""
     return prompt

+
 def generate_entities_flan_alpaca_inference_api(prompt):
     API_URL = "https://api-inference.huggingface.co/models/declare-lab/flan-alpaca-xl"
     API_TOKEN = st.secrets["hg_key"]
     headers = {"Authorization": f"Bearer {API_TOKEN}"}
     payload = {
         "inputs": prompt,
-        "parameters": {
-
+        "parameters": {
+            "do_sample": True,
+            "temperature": 0.1,
+            "max_length": 80,
+        },
+        "options": {"use_cache": False, "wait_for_model": True},
     }
     try:
         data = json.dumps(payload)
+        # Key not used as headers=headers not passed
         response = requests.request("POST", API_URL, data=data)
-        output = json.loads(response.content.decode("utf-8"))[0][
+        output = json.loads(response.content.decode("utf-8"))[0][
+            "generated_text"
+        ]
     except:
         output = ""
     print(output)
@@ -536,7 +550,7 @@ def generate_entities_flan_alpaca_inference_api(prompt):

 def generate_entities_flan_alpaca_checkpoint(model, tokenizer, prompt):
     model_inputs = tokenizer(prompt, return_tensors="pt")
-    input_ids =
+    input_ids = model_inputs["input_ids"]
     generation_output = model.generate(
         input_ids=input_ids,
         temperature=0.1,
@@ -547,9 +561,9 @@ def generate_entities_flan_alpaca_checkpoint(model, tokenizer, prompt):
     return output


-def format_entities_flan_alpaca(
+def format_entities_flan_alpaca(values):
     """
-    Extracts the text for each entity from the output generated by the
+    Extracts the text for each entity from the output generated by the
     Flan-Alpaca model.
     """
     try:
@@ -560,22 +574,22 @@ def format_entities_flan_alpaca(model_output):
     year = None
     try:
         company = company_string.split(" - ")[1].lower()
-        company = None if company.lower() ==
+        company = None if company.lower() == "none" else company
     except:
         company = None
     try:
         quarter = quarter_string.split(" - ")[1]
-        quarter = None if quarter.lower() ==
+        quarter = None if quarter.lower() == "none" else quarter

     except:
         quarter = None
     try:
         year = year_string.split(" - ")[1]
-        year = None if year.lower() ==
+        year = None if year.lower() == "none" else year

     except:
         year = None
-
+
     print((company, quarter, year))
     return company, quarter, year

@@ -586,34 +600,27 @@ def extract_quarter_year(string):
     if year_match:
         year = year_match.group()
     else:
-
+        year = None

     # Extract quarter from string
     quarter_match = re.search(r"Q\d", string)
     if quarter_match:
         quarter = "Q" + quarter_match.group()[1]
     else:
-
+        quarter = None

     return quarter, year


-def
+def extract_ticker_spacy(query, model):
     doc = model(query)
     entities = {ent.label_: ent.text for ent in doc.ents}
+    print(entities.keys())
     if "ORG" in entities.keys():
         company = entities["ORG"].lower()
-        if "DATE" in entities.keys():
-            quarter, year = extract_quarter_year(entities["DATE"])
-            return company, quarter, year
-        else:
-            return company, None, None
     else:
-
-
-        return None, quarter, year
-    else:
-        return None, None, None
+        company = None
+    return company


 def clean_entities(company, quarter, year):
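One aside on the new inference-API block: the added comment notes that headers is built but never passed, so the request currently goes out without the Hugging Face token. A minimal sketch of the same call with the Authorization header actually attached (same endpoint, payload shape, and response parsing as in the diff; the function name here is ours, not part of the commit):

import json
import requests

API_URL = "https://api-inference.huggingface.co/models/declare-lab/flan-alpaca-xl"


def query_flan_alpaca(prompt, api_token):
    # Same payload as generate_entities_flan_alpaca_inference_api,
    # but the Authorization header is sent with the request.
    headers = {"Authorization": f"Bearer {api_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {"do_sample": True, "temperature": 0.1, "max_length": 80},
        "options": {"use_cache": False, "wait_for_model": True},
    }
    response = requests.post(API_URL, headers=headers, data=json.dumps(payload))
    try:
        return response.json()[0]["generated_text"]
    except (ValueError, KeyError, IndexError):
        return ""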