|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import random |
|
|
import shutil |
|
|
import hashlib |
|
|
import uuid |
|
|
from typing import List |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
import time |
|
|
import threading |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from PIL import Image, ImageOps |
|
|
from matplotlib import cm |
|
|
|
|
|
import cv2 |
|
|
from fastapi import FastAPI, File, UploadFile, Form, Request, Depends |
|
|
from fastapi.responses import HTMLResponse, RedirectResponse |
|
|
from fastapi.templating import Jinja2Templates |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
|
|
|
# Put this script's own directory on sys.path so the project-local `models`
# package imported below can be found even when the process is started from a
# different working directory.
sys.path.append(os.path.abspath(os.path.dirname(__file__)))

# Project-local helpers: factories for the text tokenizer and the image
# transform, plus the two single-modality classifiers that the fusion model
# defined below wraps.
from models.densenet.preprocess.preprocessingwangchan import get_tokenizer, get_transforms

from models.densenet.train_densenet_only import DenseNet121Classifier

from models.densenet.train_text_only import TextClassifier
|
|
# Seed every RNG source the app touches so repeated runs behave the same.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Artifacts produced by the fusion training run (resolved relative to the
# current working directory).
FUSION_LABELMAP_PATH = "models/densenet/label_map_fusion_densenet.json"
FUSION_WEIGHTS_PATH = "models/densenet/best_fusion_densenet.pth"

with open(FUSION_LABELMAP_PATH, "r", encoding="utf-8") as f:
    label_map = json.load(f)

# Label names ordered by their class index, so class_names[i] names output
# logit i -- assuming the map's indices are exactly 0..N-1 (TODO confirm).
class_names = sorted(label_map, key=label_map.get)
NUM_CLASSES = len(class_names)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"🧠 Using device: {device}")
|
|
class FusionDenseNetText(nn.Module):
    """Late-fusion classifier combining an image branch and a text branch.

    Each branch emits ``num_classes`` logits; the two logit vectors are
    concatenated along dim 1 and passed through a small MLP head
    (Linear -> ReLU -> Dropout -> Linear) that produces the fused logits.

    NOTE: the attribute names ``image_model`` / ``text_model`` / ``fusion`` and
    the layer order inside ``fusion`` determine the state-dict keys of the
    saved checkpoint, so they must not change.
    """

    def __init__(self, num_classes, dropout=0.3):
        super().__init__()
        hidden_units = 128
        self.image_model = DenseNet121Classifier(num_classes=num_classes)
        self.text_model = TextClassifier(num_classes=num_classes)
        head_layers = [
            nn.Linear(2 * num_classes, hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, num_classes),
        ]
        self.fusion = nn.Sequential(*head_layers)

    def forward(self, image, input_ids, attention_mask):
        """Return ``(fused_logits, image_logits, text_logits)`` for a batch."""
        img_logits = self.image_model(image)
        txt_logits = self.text_model(input_ids, attention_mask)
        joined = torch.cat((img_logits, txt_logits), dim=1)
        return self.fusion(joined), img_logits, txt_logits
|
|
print("🔄 Loading AI model...")
fusion_model = FusionDenseNetText(num_classes=NUM_CLASSES)
fusion_model.to(device)  # nn.Module.to moves parameters in place
# NOTE(review): depending on the torch version, torch.load may unpickle
# arbitrary objects -- only point FUSION_WEIGHTS_PATH at trusted files.
fusion_model.load_state_dict(torch.load(FUSION_WEIGHTS_PATH, map_location=device))
fusion_model.eval()  # inference behaviour for layers such as the head's Dropout
print("✅ AI Model loaded successfully!")

# Preprocessing shared by every request: text tokenizer and 224x224 image transform.
tokenizer = get_tokenizer()
transform = get_transforms((224, 224))
|
|
def _find_last_conv2d(mod: torch.nn.Module): |
|
|
last = None |
|
|
for m in mod.modules(): |
|
|
if isinstance(m, torch.nn.Conv2d): last = m |
|
|
return last |
|
|
def compute_gradcam_overlay(img_pil, image_tensor, target_class_idx):
    """Render a Grad-CAM heat-map for ``target_class_idx`` blended over ``img_pil``.

    Only the image branch of the global ``fusion_model`` is used. Its last
    ``Conv2d`` is the target layer, and the explained score comes from the
    image-branch logits, not the fused logits.

    Args:
        img_pil: Original PIL image. The CAM is resized to its size and blended
            over its RGB pixels.
        image_tensor: Preprocessed batch fed to the image branch. Only sample 0
            is visualised.
        target_class_idx: Index of the logit to explain.

    Returns:
        ``uint8`` array (H, W, 3): 60 % image + 40 % "jet" heat-map.
        Returns ``None`` when the branch has no ``Conv2d`` or the layer was
        never reached in the forward pass.
    """
    img_branch = fusion_model.image_model
    target_layer = _find_last_conv2d(img_branch)
    if target_layer is None:
        return None

    captured = []

    def _capture(_module, _inputs, output):
        captured.append(output)

    handle = target_layer.register_forward_hook(_capture)
    try:
        # Grad-CAM needs an autograd graph even if the caller runs under
        # torch.no_grad(). A grad-requiring input also keeps the graph alive
        # when the branch's parameters happen to be frozen.
        with torch.enable_grad():
            x = image_tensor.detach().clone().requires_grad_(True)
            logits_img = img_branch(x)
            if not captured:
                return None
            act_t = captured[-1]
            score = logits_img[0, target_class_idx]
            # autograd.grad returns d(score)/d(activation) directly. Unlike
            # .backward() it never writes into the parameters' .grad buffers,
            # and it needs no full backward hook.
            (grad_t,) = torch.autograd.grad(score, act_t)
    finally:
        handle.remove()

    act = act_t.detach()[0]
    grad = grad_t.detach()[0]
    # Channel weights = spatially averaged gradients; CAM = ReLU of the
    # weighted channel sum, min-max normalised to [0, 1].
    weights = grad.mean(dim=(1, 2))
    cam = torch.relu((weights[:, None, None] * act).sum(dim=0))
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)

    cam_img = Image.fromarray((cam.cpu().numpy() * 255).astype(np.uint8))
    cam_img = cam_img.resize(img_pil.size, Image.BILINEAR)
    cam_np = np.asarray(cam_img).astype(np.float32) / 255.0

    try:
        from matplotlib import colormaps  # registry API, matplotlib >= 3.5
        jet = colormaps["jet"]
    except ImportError:
        jet = cm.get_cmap("jet")  # older matplotlib; removed in 3.9
    heatmap = jet(cam_np)[:, :, :3]

    img_np = np.asarray(img_pil.convert("RGB")).astype(np.float32) / 255.0
    overlay = 0.6 * img_np + 0.4 * heatmap
    return np.clip(overlay * 255, 0, 255).astype(np.uint8)
|
|
|
|
|
|
|
|
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Uploaded images are written here temporarily while a request is processed.
os.makedirs("uploads", exist_ok=True)

# In-memory result store: result_id -> {"data": ..., "created_at": epoch seconds}.
# Entries older than EXPIRATION_MINUTES are dropped. The dict is shared with a
# background cleanup thread, so every access goes through cache_lock.
EXPIRATION_MINUTES = 10
results_cache: dict = {}
cache_lock = threading.Lock()
|
|
|
|
|
def cleanup_expired_cache():
    """Purge expired entries from ``results_cache`` forever (background thread).

    Every 60 seconds, while holding ``cache_lock``, removes each entry whose
    ``created_at`` timestamp is more than ``EXPIRATION_MINUTES`` old and logs
    the removed key. Never returns.
    """
    while True:
        with cache_lock:
            now = time.time()
            ttl_seconds = EXPIRATION_MINUTES * 60
            stale_keys = [
                key
                for key, entry in results_cache.items()
                if now - entry["created_at"] > ttl_seconds
            ]
            for stale in stale_keys:
                del results_cache[stale]
                print(f"🧹 Cache expired and removed for key: {stale}")
        time.sleep(60)
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Start the daemon thread that periodically purges expired cache entries."""
    worker = threading.Thread(daemon=True, target=cleanup_expired_cache)
    worker.start()
    print("🗑️ Cache cleanup task started.")
|
|
|
|
|
# Form checkbox value -> Thai phrase used in the text prompt given to the model.
SYMPTOM_MAP = dict(
    noSymptoms="ไม่มีอาการ",        # no symptoms
    drinkAlcohol="ดื่มเหล้า",        # drinks alcohol
    smoking="สูบบุหรี่",             # smokes
    chewBetelNut="เคี้ยวหมาก",       # chews betel nut
    eatSpicyFood="กินเผ็ดแสบ",       # stings when eating spicy food
    wipeOff="เช็ดออกได้",            # lesion can be wiped off
    alwaysHurts="เจ็บเมื่อโดนแผล",    # hurts when the lesion is touched
)
|
|
def _image_to_base64(img):
    """Encode a PIL image as a base64 JPEG string (UTF-8 text)."""
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def process_with_ai_model(image_path: str, prompt_text: str):
    """Classify one image plus symptom prompt with the fusion model.

    Args:
        image_path: Path of the uploaded image on disk.
        prompt_text: Symptom/lifestyle text fed to the text branch.

    Returns:
        ``(original_b64, gradcam_b64, pred_label, confidence)``:
        - ``original_b64`` and ``gradcam_b64`` are base64-encoded JPEGs.
        - ``pred_label`` comes from ``class_names``.
        - ``confidence`` is the fused softmax probability of that label, as a
          percentage string with two decimals.
        If Grad-CAM cannot be produced, ``gradcam_b64`` falls back to the
        original image. On any other failure returns
        ``(None, None, "Error", "0.00")``.
    """
    import traceback

    try:
        # The context manager closes the file handle before the caller deletes
        # the file. Pillow may keep it open (e.g. multi-frame images) otherwise.
        # convert() returns a fully loaded, independent image.
        with Image.open(image_path) as raw:
            image_pil = ImageOps.exif_transpose(raw).convert("RGB")

        image_tensor = transform(image_pil).unsqueeze(0).to(device)
        enc = tokenizer(prompt_text, return_tensors="pt", padding="max_length",
                        truncation=True, max_length=128)
        ids = enc["input_ids"].to(device)
        mask = enc["attention_mask"].to(device)

        with torch.no_grad():
            fused_logits, _, _ = fusion_model(image_tensor, ids, mask)
        probs_fused = torch.softmax(fused_logits, dim=1)[0].cpu().numpy()
        pred_idx = int(np.argmax(probs_fused))
        pred_label = class_names[pred_idx]
        confidence = float(probs_fused[pred_idx]) * 100

        original_b64 = _image_to_base64(image_pil)

        # Grad-CAM is only a visual aid. A failure there must not throw away
        # the prediction, so fall back to the original image as for None.
        try:
            overlay_np = compute_gradcam_overlay(image_pil, image_tensor, pred_idx)
        except Exception as cam_err:
            print(f"⚠️ Grad-CAM failed, showing original image: {cam_err}")
            overlay_np = None
        if overlay_np is None:
            gradcam_b64 = original_b64
        else:
            gradcam_b64 = _image_to_base64(Image.fromarray(overlay_np))

        return original_b64, gradcam_b64, pred_label, f"{confidence:.2f}"

    except Exception as e:
        print(f"❌ Error during AI processing: {e}")
        traceback.print_exc()
        return None, None, "Error", "0.00"
|
|
|
|
|
@app.get("/", response_class=RedirectResponse)
async def root():
    """Send visitors of the site root to the upload form."""
    target = "/detect"
    return RedirectResponse(url=target)
|
|
@app.get("/detect", response_class=HTMLResponse)
async def show_upload_form(request: Request):
    """Render ``detect.html`` with only the request, i.e. no result data."""
    page_context = {"request": request}
    return templates.TemplateResponse("detect.html", page_context)
|
|
|
|
|
def _build_symptom_prompt(checkboxes, symptom_text):
    """Combine ticked checkboxes and free text into the model's text prompt.

    - Checkbox values are mapped through ``SYMPTOM_MAP``; unknown values are
      ignored. The mapped phrases are de-duplicated and sorted.
    - When "no symptoms" (ไม่มีอาการ) is ticked, the contradictory pain items
      are dropped.
    - Non-blank free text is appended after ``"; "``.
    - An empty result becomes ``"ไม่มีอาการ"``.
    """
    selected = {SYMPTOM_MAP[cb] for cb in checkboxes if SYMPTOM_MAP.get(cb)}
    if "ไม่มีอาการ" in selected:
        # "No symptoms" contradicts these pain-related items, so drop them;
        # lifestyle/pattern items are kept.
        selected -= {"เจ็บเมื่อโดนแผล", "กินเผ็ดแสบ"}

    parts = []
    if selected:
        parts.append(" ".join(sorted(selected)))
    free_text = (symptom_text or "").strip()
    if free_text:
        parts.append(free_text)
    return "; ".join(parts) if parts else "ไม่มีอาการ"


@app.post("/uploaded")
async def handle_upload(
    request: Request,
    file: UploadFile = File(...),
    checkboxes: List[str] = Form([]),
    symptom_text: str = Form("")
):
    """Run the model on an uploaded image and redirect to its results page.

    The upload is stored in a temporary file under ``uploads/``, classified
    with the prompt built from the form, then deleted. The result is cached
    under a fresh id, and the client is redirected (303) to ``show_results``.
    """
    # Never let the client-supplied filename (which may contain "/" or "..")
    # reach the filesystem path: use a fresh UUID and keep only a plain
    # extension.
    suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
    if not suffix[1:].isalnum():
        suffix = ""
    temp_filepath = os.path.join("uploads", f"{uuid.uuid4().hex}{suffix}")

    try:
        with open(temp_filepath, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        final_prompt = _build_symptom_prompt(checkboxes, symptom_text)

        # NOTE(review): inference runs synchronously on the event loop. This
        # also serialises use of the shared model and its Grad-CAM hooks; add
        # a lock before moving it to a thread pool.
        image_b64, gradcam_b64, name_out, eva_output = process_with_ai_model(
            image_path=temp_filepath, prompt_text=final_prompt
        )
    finally:
        # Always remove the upload, even if saving or inference raised.
        try:
            os.remove(temp_filepath)
        except FileNotFoundError:
            pass

    result_id = str(uuid.uuid4())
    result_data = {
        "image_b64_data": image_b64,
        "gradcam_b64_data": gradcam_b64,
        "name_out": name_out,
        "eva_output": eva_output,
    }
    with cache_lock:
        results_cache[result_id] = {"data": result_data, "created_at": time.time()}

    # 303 See Other makes the browser GET the results page (post/redirect/get).
    results_url = request.url_for('show_results', result_id=result_id)
    return RedirectResponse(url=results_url, status_code=303)
|
|
|
|
|
@app.get("/results/{result_id}", response_class=HTMLResponse)
async def show_results(request: Request, result_id: str):
    """Render a cached result, or redirect to the upload form if it is gone.

    A result older than ``EXPIRATION_MINUTES`` is evicted on access and
    treated as missing.
    """
    with cache_lock:
        cached_item = results_cache.get(result_id)
        expired = cached_item is not None and (
            time.time() - cached_item["created_at"] > EXPIRATION_MINUTES * 60
        )
        if expired:
            # Evict inside the same critical section as the lookup. Otherwise
            # the cleanup thread could delete the key in between, and this
            # `del` would raise KeyError (HTTP 500).
            del results_cache[result_id]

    if cached_item is None or expired:
        return RedirectResponse(url="/detect")

    context = {"request": request, **cached_item["data"]}
    return templates.TemplateResponse("detect.html", context)
|
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")