#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Custom handler for Llama 2 text-generation model.
Author: Henry
Created on: Mon Nov 20, 2023
This module defines a custom handler for the Llama 2 text-generation model,
utilizing Hugging Face's transformers pipeline. It's designed to process requests
for text generation, leveraging the capabilities of the Llama 2 model.
"""
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline, BitsAndBytesConfig
from typing import Dict, List, Any
import logging
import sys
# Route INFO-and-above log records to stdout, prefixed with level and timestamp.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    format='%(levelname)s - %(asctime)s - %(message)s',
    level=logging.INFO,
    handlers=[_stdout_handler],
)
class EndpointHandler:
    """
    Handler class for Llama 2 text-generation model inference.

    Loads a 4-bit (NF4) quantized Llama 2 model through Hugging Face
    transformers and exposes a callable interface that turns an inference
    request payload into generated text.
    """
    def __init__(self, path: str = ""):
        """
        Initialize the pipeline for the Llama 2 text-generation model.

        Args:
            path (str): Path (or hub id) of the model to load, defaults to
                an empty string.
        """
        # 4-bit NF4 quantization with double quantization; matrix compute
        # runs in bfloat16 (not float16 — the dtype below is authoritative).
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        tokenizer = LlamaTokenizer.from_pretrained(path)
        # device_map=0 pins the whole model onto the first GPU.
        model = LlamaForCausalLM.from_pretrained(
            path, device_map=0, quantization_config=self.bnb_config
        )
        self.pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process a request for text generation.

        Args:
            data (Dict[str, Any]): Request payload. The "inputs" key holds
                the prompt (the whole payload is used as a fallback when the
                key is absent); the optional "parameters" key holds extra
                generation kwargs forwarded to the pipeline.

        Returns:
            List[Dict[str, Any]]: The generated text as a list of dictionaries.

        Raises:
            ValueError: If no inputs could be extracted from the payload.
        """
        # Lazy %-style args avoid formatting work when INFO is disabled.
        logging.info("Received data: %s", data)
        # Pop "parameters" BEFORE resolving the inputs fallback: when the
        # payload has no "inputs" key, `inputs` aliases `data` itself, and
        # popping afterwards would silently strip "parameters" out of the
        # prompt dict as well.
        parameters = data.pop("parameters", None)
        inputs = data.pop("inputs", data)
        # Validate the input data.
        if not inputs:
            raise ValueError(f"'inputs' is required, got {inputs!r}.")
        logging.info("Extracted inputs: %s", inputs)
        logging.info("Extracted parameters: %s", parameters)
        # Forward optional generation kwargs to the pipeline when provided.
        if parameters is not None:
            return self.pipeline(inputs, **parameters)
        return self.pipeline(inputs)
|