"""
Farm Segmentation API - Gradio Interface
SegFormer (ADE20K-finetuned) models for agricultural image segmentation
"""

import base64
import io
import json
import random
import re
import time
from typing import Any, Dict, List

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

# Import segmentation models
try:
    from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

    MODELS_AVAILABLE = True
except ImportError:
    MODELS_AVAILABLE = False


class SegmentationAPI:
    """Load SegFormer checkpoints and summarise their output in agricultural terms."""

    def __init__(self):
        self.models = {}
        self.processors = {}
        # Short model key -> Hugging Face checkpoint (all fine-tuned on ADE20K, 512x512).
        self.model_configs = {
            "segformer_b0": "nvidia/segformer-b0-finetuned-ade-512-512",
            "segformer_b1": "nvidia/segformer-b1-finetuned-ade-512-512",
            "segformer_b2": "nvidia/segformer-b2-finetuned-ade-512-512",
        }

        # Segmentation classes relevant to agriculture.
        # Keywords are matched as whole words against the model's class names.
        self.ag_classes = {
            "soil": ["dirt", "earth", "ground", "soil", "mud"],
            "vegetation": ["grass", "tree", "plant", "leaf", "crop", "vegetation"],
            "water": ["water", "river", "pond", "irrigation"],
            "sky": ["sky", "cloud"],
            "building": ["building", "structure", "barn", "greenhouse"],
            "road": ["road", "path", "walkway"],
            "equipment": ["machine", "tractor", "equipment"],
        }

        if MODELS_AVAILABLE:
            self.load_models()

    def load_models(self):
        """Load every configured checkpoint.

        Loading is best-effort: a checkpoint that fails to load is reported and
        skipped, so the remaining models stay usable.
        """
        for model_key, model_name in self.model_configs.items():
            try:
                print(f"Loading {model_name}...")
                processor = SegformerImageProcessor.from_pretrained(model_name)
                model = SegformerForSemanticSegmentation.from_pretrained(model_name)
                self.processors[model_key] = processor
                self.models[model_key] = model
                print(f"✅ {model_name} loaded successfully")
            except Exception as e:
                print(f"❌ Failed to load {model_name}: {e}")

    def segment_image(self, image: Image.Image, model_key: str = "segformer_b1") -> Dict[str, Any]:
        """Segment an agricultural image with the selected model.

        Args:
            image: Input PIL image. Non-RGB modes are converted to RGB first.
            model_key: Key into ``model_configs`` selecting the checkpoint.

        Returns:
            On success, a dict with ``segments_detected``, ``segments`` (see
            ``analyze_segments``), ``segmentation_map`` (uint8 RGB array),
            ``processing_time`` (seconds, 2 dp) and ``model_used``.
            On failure, ``{"error": <message>}``.
        """
        if not MODELS_AVAILABLE or model_key not in self.models:
            return {"error": "Model not available"}

        start_time = time.perf_counter()

        try:
            processor = self.processors[model_key]
            model = self.models[model_key]

            # The processor expects 3-channel input; RGBA / greyscale / palette uploads would break it.
            if image.mode != "RGB":
                image = image.convert("RGB")

            inputs = processor(images=image, return_tensors="pt")

            # Run inference
            with torch.no_grad():
                outputs = model(**inputs)

            # Resize logits to the input's (height, width) before taking the per-pixel argmax.
            upsampled_logits = torch.nn.functional.interpolate(
                outputs.logits,
                size=image.size[::-1],  # PIL size is (width, height)
                mode="bilinear",
                align_corners=False,
            )
            # Take batch element 0 explicitly; squeeze() would also drop a height/width of 1.
            predicted_segmentation = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

            segments_info = self.analyze_segments(predicted_segmentation, model)
            colored_segmentation = self.create_colored_segmentation(predicted_segmentation, model)

            processing_time = time.perf_counter() - start_time

            return {
                "segments_detected": len(segments_info),
                "segments": segments_info,
                "segmentation_map": colored_segmentation,
                "processing_time": round(processing_time, 2),
                "model_used": model_key,
            }

        except Exception as e:
            # API boundary: report the failure to the caller instead of crashing the UI.
            return {"error": str(e)}

    def analyze_segments(self, segmentation: np.ndarray, model) -> List[Dict[str, Any]]:
        """Summarise every class that covers more than 1% of the label map.

        Args:
            segmentation: Integer label map (one class id per pixel).
            model: Model whose ``config.id2label`` names the class ids.

        Returns:
            One dict per class (``class``, ``agricultural_category``,
            ``pixel_count``, ``percentage``, ``label_id``), sorted by
            coverage percentage, largest first.
        """
        total_pixels = segmentation.size
        id2label = model.config.id2label

        # One counting pass instead of building a boolean mask per label.
        labels, counts = np.unique(segmentation, return_counts=True)

        segments_info = []
        for label, pixel_count in zip(labels.tolist(), counts.tolist()):
            percentage = (pixel_count / total_pixels) * 100
            if percentage <= 1.0:  # Only include segments > 1%
                continue
            class_name = id2label.get(label, f"class_{label}")
            segments_info.append({
                "class": class_name,
                "agricultural_category": self.classify_agricultural_segment(class_name),
                "pixel_count": int(pixel_count),
                "percentage": round(percentage, 2),
                "label_id": int(label),
            })

        return sorted(segments_info, key=lambda seg: seg["percentage"], reverse=True)

    def classify_agricultural_segment(self, class_name: str) -> str:
        """Map a model class name to an ``ag_classes`` category, or ``"other"``.

        Keywords must match whole words. Plain substring matching would file
        labels such as "skyscraper" under sky and "streetlight" under
        vegetation (via "tree").
        """
        class_lower = class_name.lower()
        for ag_category, keywords in self.ag_classes.items():
            for keyword in keywords:
                if re.search(rf"\b{re.escape(keyword)}\b", class_lower):
                    return ag_category
        return "other"

    def create_colored_segmentation(self, segmentation: np.ndarray, model) -> np.ndarray:
        """Colour a label map with a fixed palette.

        Returns a uint8 array of shape ``segmentation.shape + (3,)``. Label ids
        beyond the palette size wrap around via modulo.
        """
        num_classes = len(model.config.id2label)
        palette = np.asarray(self.generate_colors(num_classes), dtype=np.uint8)
        # Fancy indexing colours every pixel in one vectorised step.
        return palette[segmentation % len(palette)]

    def generate_colors(self, num_colors: int) -> List[List[int]]:
        """Return ``num_colors`` reproducible RGB triples with channels in [50, 255]."""
        # A private generator seeded with 42 yields the same colours as seeding the
        # global `random` module, without resetting the process-wide RNG.
        rng = random.Random(42)
        return [[rng.randint(50, 255) for _ in range(3)] for _ in range(num_colors)]


# Initialize API (loads every configured checkpoint at import time)
api = SegmentationAPI()


def predict_segmentation(image, model_choice):
    """Gradio callback: segment ``image`` and build the three UI outputs.

    Returns:
        ``(original_image, segmentation_image, summary_text)``; the images are
        ``None`` when there is no input or segmentation failed.
    """
    if image is None:
        return None, None, "Please upload an image"

    # Convert to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    results = api.segment_image(image, model_choice)

    if "error" in results:
        return None, None, f"Error: {results['error']}"

    segmentation_img = Image.fromarray(results["segmentation_map"])

    summary_lines = [
        "🏞️ **Agricultural Segmentation Analysis**",
        "",
        f"📊 **Segments Detected**: {results['segments_detected']}",
        f"⏱️ **Processing Time**: {results['processing_time']}s",
        f"🤖 **Model**: {results['model_used']}",
        "",
        "**🌾 Agricultural Composition**:",
    ]
    # Top 10 segments
    summary_lines.extend(
        f"• **{segment['class']}** ({segment['agricultural_category']}): {segment['percentage']:.1f}%"
        for segment in results["segments"][:10]
    )

    return image, segmentation_img, "\n".join(summary_lines)


API_DOCS_MARKDOWN = """
## 🚀 API Endpoint

**POST** `/api/predict`

### Request Format
```json
{
  "data": ["data:image/png;base64,...", "segformer_b1"]
}
```

### Model Options
- **segformer_b0**: Fastest, basic segmentation
- **segformer_b1**: Balanced speed and accuracy (recommended)
- **segformer_b2**: Higher accuracy, slower processing

### Response Format
```json
{
  "segments_detected": 8,
  "segments": [
    {
      "class": "grass",
      "agricultural_category": "vegetation",
      "pixel_count": 145632,
      "percentage": 35.2,
      "label_id": 9
    }
  ],
  "processing_time": 2.1
}
```

### Agricultural Categories
- **soil**: Ground, dirt, earth
- **vegetation**: Crops, grass, trees
- **water**: Irrigation, ponds, rivers
- **sky**: Sky, clouds
- **building**: Barns, greenhouses, structures
- **road**: Roads, paths, walkways
- **equipment**: Tractors, machinery
- **other**: Uncategorized segments
"""


# Gradio Interface
with gr.Blocks(title="🏞️ Farm Segmentation API") as app:
    gr.Markdown("# 🏞️ Farm Segmentation API")
    gr.Markdown("AI-powered agricultural image segmentation and land analysis")

    with gr.Tab("🌾 Field Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Farm Image")
                model_choice = gr.Dropdown(
                    choices=list(api.model_configs),
                    value="segformer_b1",
                    label="Select Model",
                )
                segment_btn = gr.Button("🔍 Analyze Segments", variant="primary")

            with gr.Column():
                original_image = gr.Image(label="Original Image")
                segmented_image = gr.Image(label="Segmentation Map")
                results_text = gr.Textbox(label="Segmentation Analysis", lines=15)

        segment_btn.click(
            predict_segmentation,
            inputs=[image_input, model_choice],
            outputs=[original_image, segmented_image, results_text],
        )

    with gr.Tab("📡 API Documentation"):
        gr.Markdown(API_DOCS_MARKDOWN)


if __name__ == "__main__":
    app.launch()