import torch
from typing import Any, Dict
from transformers import AutoTokenizer, BitsAndBytesConfig
from peft import AutoPeftModelForCausalLM
def parse_output(text):
    """Return only the model's answer, i.e. everything after the response marker."""
    marker = "### Response:"
    if marker in text:
        pos = text.find(marker) + len(marker)
    else:
        pos = 0
    # Strip any stray special tokens the tokenizer may have left in the text.
    return text[pos:].replace("<pad>", "").replace("</s>", "").strip()
class EndpointHandler:
    def __init__(self, path="./", use_bnb=True):
        # 4-bit NF4 quantization with double quantization keeps memory usage
        # low while computing in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        # Load the PEFT adapter together with its base model; quantization is
        # skipped when use_bnb is False.
        self.model = AutoPeftModelForCausalLM.from_pretrained(
            path, quantization_config=bnb_config if use_bnb else None
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
    def __call__(self, data: Any) -> Dict[str, str]:
        inputs = data.get("inputs", data)
        # Alpaca-style prompt template used during fine-tuning.
        prompt = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request."
            f"\n\n### Instruction: \n{inputs}\n\n### Response: \n"
        )
        parameters = data.get("parameters", {})
        model_inputs = self.tokenizer(
            prompt, return_tensors="pt", return_token_type_ids=False
        ).to(self.model.device)
        # Pass any generation parameters (max_new_tokens, temperature, ...)
        # straight through to generate().
        outputs = self.model.generate(**model_inputs, **parameters)
        return {
            "generated_text": parse_output(
                self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            )
        }
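

# Minimal local smoke test: a sketch, assuming the fine-tuned adapter and its
# tokenizer live in the current directory and a CUDA GPU is available for the
# 4-bit load. The instruction text and generation parameters below are purely
# illustrative, not part of the handler contract.
if __name__ == "__main__":
    handler = EndpointHandler(path="./")
    result = handler(
        {
            "inputs": "Explain LoRA fine-tuning in two sentences.",
            "parameters": {"max_new_tokens": 128, "do_sample": False},
        }
    )
    print(result["generated_text"])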