|
|
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
from typing import List, Dict, Any, Optional, Union |
|
|
import torch |
|
|
import logging |
|
|
from .utils import chunk_text |
|
|
|
|
|
# Module-level logger named after this module, so the application's logging
# configuration can filter or route this module's messages by name.
logger = logging.getLogger(__name__)
|
|
|
|
|
class BookSummarizer:
    """
    Handles AI-powered text summarization using transformer models.

    Constructing the object only records the configuration. The model is
    loaded lazily: either explicitly through ``load_model`` or on the first
    ``summarize_text`` call whose input is long enough to need the model.
    """

    # Texts with fewer words than this are returned verbatim by
    # summarize_text instead of being sent to the model.
    SHORT_TEXT_WORD_LIMIT = 50

    # summarize_book runs a second summarization pass over the joined chunk
    # summaries when they contain more words than this threshold.
    FINAL_PASS_WORD_THRESHOLD = 500
    FINAL_PASS_MAX_LENGTH = 300
    FINAL_PASS_MIN_LENGTH = 100

    # Number of leading characters of a chunk used as its stand-in summary
    # when summarizing that chunk fails.
    FALLBACK_EXCERPT_CHARS = 200

    def __init__(self, model_name: str = "facebook/bart-large-cnn") -> None:
        """
        Initialize the summarizer with a specific model.

        No model is loaded here; ``summarizer``, ``tokenizer`` and ``model``
        start out as ``None`` and are filled in by ``load_model``.

        Args:
            model_name: Hugging Face model name for summarization
        """
        self.model_name = model_name
        self.summarizer = None
        self.tokenizer = None
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.info("Initializing summarizer with model: %s", model_name)
        logger.info("Using device: %s", self.device)

    def load_model(self) -> None:
        """
        Load the summarization model and tokenizer.

        On success ``self.tokenizer``, ``self.model`` and ``self.summarizer``
        are all set. On failure none of them is modified, the error is
        logged, and the original exception is re-raised.

        Raises:
            Exception: Whatever the underlying loading calls raise.
        """
        try:
            logger.info("Loading summarization model...")

            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
            model.to(self.device)

            # The pipeline is given device index 0 when running on CUDA and
            # -1 otherwise.
            summarizer = pipeline(
                "summarization",
                model=model,
                tokenizer=tokenizer,
                device=0 if self.device == "cuda" else -1,
            )
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

        # Publish the attributes only after every step succeeded, so a failed
        # load never leaves the instance half-initialised.
        self.tokenizer = tokenizer
        self.model = model
        self.summarizer = summarizer

        logger.info("Model loaded successfully")

    def summarize_text(self, text: str, max_length: int = 150, min_length: int = 50,
                       do_sample: bool = False) -> Dict[str, Any]:
        """
        Summarize a single text chunk.

        Texts with fewer than ``SHORT_TEXT_WORD_LIMIT`` whitespace-separated
        words are returned unchanged, without loading or calling the model.

        Args:
            text: Text to summarize
            max_length: Maximum length of summary (forwarded to the pipeline)
            min_length: Minimum length of summary (forwarded to the pipeline)
            do_sample: Whether to use sampling for generation

        Returns:
            On success, a dictionary with ``success`` (True), ``summary``,
            ``original_length`` and ``summary_length`` (word counts) and
            ``compression_ratio`` (summary words / original words).
            On failure, a dictionary with ``success`` (False), an empty
            ``summary`` and ``error`` (the exception message). Exceptions
            are logged and never propagated to the caller.
        """
        try:
            original_words = len(text.split())

            # Decide on pass-through before touching the model: a text that
            # is returned verbatim must not trigger (or fail on) a model load.
            if original_words < self.SHORT_TEXT_WORD_LIMIT:
                return {
                    'success': True,
                    'summary': text,
                    'original_length': original_words,
                    'summary_length': original_words,
                    'compression_ratio': 1.0,
                }

            if self.summarizer is None:
                self.load_model()

            summary_result = self.summarizer(
                text,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                truncation=True,
            )
            summary = summary_result[0]['summary_text']
            summary_words = len(summary.split())

            return {
                'success': True,
                'summary': summary,
                'original_length': original_words,
                'summary_length': summary_words,
                # original_words >= SHORT_TEXT_WORD_LIMIT here, so the
                # division cannot hit zero.
                'compression_ratio': summary_words / original_words,
            }

        except Exception as e:
            # The exception is converted into a result dict, so log it with
            # its traceback here; nothing further up will see it.
            logger.exception("Error summarizing text: %s", e)
            return {
                'success': False,
                'summary': '',
                'error': str(e),
            }

    def summarize_book(self, text: str, chunk_size: int = 1000, overlap: int = 100,
                       max_length: int = 150, min_length: int = 50) -> Dict[str, Any]:
        """
        Summarize a complete book by processing it in chunks.

        The text is split with ``chunk_text``, each chunk is summarized with
        ``summarize_text`` and the chunk summaries are joined with spaces.
        If the joined text exceeds ``FINAL_PASS_WORD_THRESHOLD`` words it is
        summarized once more; when that pass fails the joined text is kept.
        A chunk whose summarization fails is represented by its first
        ``FALLBACK_EXCERPT_CHARS`` characters followed by "...".

        Args:
            text: Complete book text
            chunk_size: Size of each text chunk (forwarded to ``chunk_text``)
            overlap: Overlap between chunks (forwarded to ``chunk_text``)
            max_length: Maximum length of each chunk summary
            min_length: Minimum length of each chunk summary

        Returns:
            On success, a dictionary with ``success`` (True), ``summary``,
            ``chunk_summaries`` (one entry per chunk, in order) and
            ``statistics`` containing ``total_chunks``, ``failed_chunks``,
            ``total_original_words``, ``total_summary_words`` (words across
            the per-chunk summaries, before any final pass),
            ``overall_compression_ratio`` and ``final_summary_length``
            (words in the returned summary).
            On failure, a dictionary with ``success`` (False), an empty
            ``summary`` and ``error`` (the exception message).
        """
        try:
            logger.info("Starting book summarization...")

            chunks = chunk_text(text, chunk_size, overlap)
            total_chunks = len(chunks)
            logger.info("Split text into %s chunks", total_chunks)

            chunk_summaries = []
            total_original_words = 0
            total_summary_words = 0
            failed_chunks = 0

            for position, chunk in enumerate(chunks, start=1):
                logger.info("Processing chunk %s/%s", position, total_chunks)

                result = self.summarize_text(chunk, max_length, min_length)

                if result['success']:
                    chunk_summary = result['summary']
                    total_original_words += result['original_length']
                    total_summary_words += result['summary_length']
                else:
                    logger.warning(
                        "Failed to summarize chunk %s: %s",
                        position,
                        result.get('error', 'Unknown error'),
                    )
                    failed_chunks += 1
                    chunk_summary = chunk[:self.FALLBACK_EXCERPT_CHARS] + "..."
                    # The excerpt ends up in the output, so the chunk must be
                    # counted too; otherwise the totals describe less text
                    # than the summary was actually built from.
                    total_original_words += len(chunk.split())
                    total_summary_words += len(chunk_summary.split())

                chunk_summaries.append(chunk_summary)

            combined_summary = " ".join(chunk_summaries)

            if len(combined_summary.split()) > self.FINAL_PASS_WORD_THRESHOLD:
                logger.info("Creating final summary from combined summaries...")
                # NOTE(review): summarize_text calls the pipeline with
                # truncation=True, so a combined summary longer than the
                # model's input limit is presumably cut off before this
                # pass - confirm whether a hierarchical pass is needed.
                final_result = self.summarize_text(
                    combined_summary,
                    max_length=self.FINAL_PASS_MAX_LENGTH,
                    min_length=self.FINAL_PASS_MIN_LENGTH,
                )
                if final_result['success']:
                    combined_summary = final_result['summary']

            if total_original_words > 0:
                overall_compression = total_summary_words / total_original_words
            else:
                overall_compression = 0

            return {
                'success': True,
                'summary': combined_summary,
                'statistics': {
                    'total_chunks': total_chunks,
                    'failed_chunks': failed_chunks,
                    'total_original_words': total_original_words,
                    'total_summary_words': total_summary_words,
                    'overall_compression_ratio': overall_compression,
                    'final_summary_length': len(combined_summary.split()),
                },
                'chunk_summaries': chunk_summaries,
            }

        except Exception as e:
            # As in summarize_text, the exception is swallowed into the
            # result, so keep the traceback in the log.
            logger.exception("Error in book summarization: %s", e)
            return {
                'success': False,
                'summary': '',
                'error': str(e),
            }

    def get_available_models(self) -> List[Dict[str, Union[str, int]]]:
        """
        Get list of available summarization models.

        Returns:
            A new list on every call; each entry has ``name``,
            ``description`` and ``max_length`` keys.
        """
        return [
            {
                'name': 'facebook/bart-large-cnn',
                'description': 'BART model fine-tuned on CNN news articles (recommended)',
                'max_length': 1024,
            },
            {
                'name': 't5-small',
                'description': 'Small T5 model, faster but less accurate',
                'max_length': 512,
            },
            {
                'name': 'facebook/bart-base',
                'description': 'Base BART model, balanced performance',
                'max_length': 1024,
            },
        ]

    def change_model(self, model_name: str) -> None:
        """
        Change the summarization model.

        Only the configuration is switched: the references to the current
        pipeline, tokenizer and model are cleared, and the new model is
        loaded by the next ``load_model`` call (which ``summarize_text``
        makes on demand).

        Args:
            model_name: New model name to use
        """
        self.model_name = model_name
        self.summarizer = None
        self.tokenizer = None
        self.model = None
        logger.info("Model changed to: %s", model_name)