# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import argparse
import base64
import logging
import os
import time
from io import BytesIO

import requests
from flask import Flask, jsonify, request
from PIL import Image

# Local helper that fetches the VLM weights (see model_downloader.py)
from model_downloader import download_vl_model
# Get environment variables
VLM_MODEL_NAME = os.getenv("VLM_MODEL_NAME", "Qwen/Qwen2.5-VL-7B-Instruct")
LOCAL_EMBED_MODEL_ID = os.getenv("LOCAL_EMBED_MODEL_ID", "CLIP-ViT-H-14")
MODEL_DIR = os.getenv("MODEL_DIR", "./models")
DOWNLOAD_MODELS = os.getenv("DOWNLOAD_MODELS", "True").lower() == "true"
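
# The variables above can be overridden at launch; a hypothetical example
# (the script filename and model ID are illustrative, not part of this file):
#   DOWNLOAD_MODELS=false VLM_MODEL_NAME=Qwen/Qwen2-VL-2B-Instruct python vlm_backend.py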
# Configure logging
logger = logging.getLogger('vlm_backend')
logging.basicConfig(
    level=logging.INFO,
    format="[%(levelname)s] %(asctime)s.%(msecs)03d [%(name)s]: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Create Flask application
app = Flask(__name__)
# Global model variables
ov_model = None
tokenizer = None

def load_models():
    global ov_model, tokenizer
    logger.info(f"Loading VLM model: {VLM_MODEL_NAME}")
    try:
        # Check whether models should be downloaded first
        if DOWNLOAD_MODELS:
            logger.info("Auto-downloading model is enabled")
            model_path = download_vl_model(VLM_MODEL_NAME, MODEL_DIR, "FP16")
            logger.info(f"Model downloaded to {model_path}")
            # In a real implementation, you would load the OpenVINO model here:
            # from optimum.intel.openvino import OVModelForCausalLM
            # from transformers import AutoTokenizer
            # tokenizer = AutoTokenizer.from_pretrained(model_path)
            # ov_model = OVModelForCausalLM.from_pretrained(model_path)
        else:
            logger.info("Using pre-downloaded model")
            # Similarly, you would load your pre-downloaded model from MODEL_DIR here
        logger.info("Models loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        return False
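
# A minimal sketch of what non-simulated loading could look like. It assumes
# a recent optimum-intel build that exposes OVModelForVisualCausalLM and a
# processor published alongside the model; this is illustrative, not the
# project's confirmed loading path.
def load_models_openvino(model_path):
    from optimum.intel import OVModelForVisualCausalLM  # lazy import: optional dependency
    from transformers import AutoProcessor

    # For VLMs, the processor bundles the tokenizer and image preprocessing
    processor = AutoProcessor.from_pretrained(model_path)
    model = OVModelForVisualCausalLM.from_pretrained(model_path)
    return model, processor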

# Simulate VLM processing
def process_image_query(image_data, query, task="vqa"):
    """
    Process an image with a query using the VLM model.

    Args:
        image_data: Base64-encoded image
        query: Text query to process
        task: The task to perform ("vqa" or "search")

    Returns:
        Dictionary with the processing results
    """
    try:
        # Decode the image
        image_bytes = base64.b64decode(image_data)
        image = Image.open(BytesIO(image_bytes))
        # Log the request
        logger.info(f"Processing {task} request with query: {query}")
        logger.info(f"Image size: {image.size}")
        # Generate a simulated response
        if task == "vqa":
            # Simulate a Visual Question Answering response
            response = {
                "answer": f"This is a simulated VLM response for: '{query}'. "
                          "In a real implementation, this would be generated by a "
                          "Vision Language Model based on the image content."
            }
        else:
            # Simulate a Visual Search response
            response = {
                "answer": f"Search results for: '{query}'. Found 5 similar images (simulated results)."
            }
        # Add a slight delay to simulate processing time
        time.sleep(1.5)
        return response
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        return {"error": str(e)}
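
# A quick local smoke test for process_image_query that bypasses HTTP; the
# sample image is synthesized in memory, so no files or network are assumed.
def _demo_process_image_query():
    buf = BytesIO()
    Image.new("RGB", (64, 64), color="white").save(buf, format="PNG")
    encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
    return process_image_query(encoded, "What is in this image?", task="vqa")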

# API endpoints
@app.route('/vqa', methods=['POST'])
def vqa_endpoint():
    try:
        # Get data from the request; silent=True returns None instead of
        # raising on a missing or malformed JSON body, so bad input yields
        # a 400 rather than a 500
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No data provided"}), 400
        # Extract image and query
        image_data = data.get('image')
        query = data.get('query')
        if not image_data or not query:
            return jsonify({"error": "Image and query are required"}), 400
        # Process the request
        result = process_image_query(image_data, query, task="vqa")
        return jsonify(result)
    except Exception as e:
        logger.error(f"Error in VQA endpoint: {str(e)}")
        return jsonify({"error": str(e)}), 500

@app.route('/search', methods=['POST'])
def search_endpoint():
    try:
        # Get data from the request (tolerant of missing/invalid JSON, as above)
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No data provided"}), 400
        # Extract image and query
        image_data = data.get('image')
        query = data.get('query')
        if not image_data or not query:
            return jsonify({"error": "Image and query are required"}), 400
        # Process the request
        result = process_image_query(image_data, query, task="search")
        return jsonify(result)
    except Exception as e:
        logger.error(f"Error in search endpoint: {str(e)}")
        return jsonify({"error": str(e)}), 500
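
# A client-side sketch for exercising the endpoints over HTTP. The base URL
# mirrors the argparse defaults below and is an assumption about how the
# server was launched; image_path is a caller-supplied local file.
def _demo_vqa_request(image_path, query, base_url="http://localhost:8399"):
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    resp = requests.post(
        f"{base_url}/vqa",
        json={"image": encoded, "query": query},
        timeout=60,
    )
    return resp.json()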

@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({"status": "ok", "model": VLM_MODEL_NAME})
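
# Example health probe against a locally running instance (assumes the
# default port below):
#   curl http://localhost:8399/health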

if __name__ == '__main__':
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Vision Language Model Backend')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to run the server on')
    parser.add_argument('--port', type=int, default=8399, help='Port to run the server on')
    parser.add_argument('--debug', action='store_true', help='Run in debug mode')
    args = parser.parse_args()
    # Load models before serving requests
    load_models()
    # Run the Flask application
    app.run(host=args.host, port=args.port, debug=args.debug)