# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from functools import cache
import sys
import os
import argparse
import logging
import datetime
import re
import string
import unicodedata
from PIL import Image
import time
import random
from io import BytesIO
import base64
import requests
import shutil
import tempfile
import copy
import json
from flask import Flask, request, jsonify

# Import utils for image processing
from utils import image_to_url, generate_image_hash

# Optional: Import model downloader
from model_downloader import download_vl_model

# Get environment variables
VLM_MODEL_NAME = os.getenv("VLM_MODEL_NAME", "Qwen/Qwen2.5-VL-7B-Instruct")
LOCAL_EMBED_MODEL_ID = os.getenv("LOCAL_EMBED_MODEL_ID", "CLIP-ViT-H-14")
MODEL_DIR = os.getenv("MODEL_DIR", "./models")
DOWNLOAD_MODELS = os.getenv("DOWNLOAD_MODELS", "True").lower() == "true"
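# Example shell configuration using the variables read above; since the flag is
# parsed with .lower() == "true", any other value (e.g. "false") skips the
# auto-download path in load_models():
#   export VLM_MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
#   export MODEL_DIR="./models"
#   export DOWNLOAD_MODELS="false"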
# Configure logging
logger = logging.getLogger('vlm_backend')
logging.basicConfig(
    level=logging.INFO,
    format="[%(levelname)s] %(asctime)s.%(msecs)03d [%(name)s]: %(message)s",
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Create Flask application
app = Flask(__name__)

# Global model variables
ov_model = None
tokenizer = None

# Load models
def load_models():
    global ov_model, tokenizer

    logger.info(f"Loading VLM model: {VLM_MODEL_NAME}")

    try:
        # Check if models should be downloaded
        if DOWNLOAD_MODELS:
            logger.info("Auto-downloading model is enabled")
            model_path = download_vl_model(VLM_MODEL_NAME, MODEL_DIR, "FP16")
            logger.info(f"Model downloaded to {model_path}")

            # In a real implementation, you would load the OpenVINO model here:
            # from optimum.intel.openvino import OVModelForCausalLM
            # from transformers import AutoTokenizer
            # tokenizer = AutoTokenizer.from_pretrained(model_path)
            # ov_model = OVModelForCausalLM.from_pretrained(model_path)
        else:
            logger.info("Using pre-downloaded model")
            # Similarly, you would load your pre-downloaded model here

        logger.info("Models loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        return False
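
# Illustrative sketch (not called by load_models above): one way the real model
# load could look with optimum-intel's OpenVINO integration, following the
# commented hints in load_models. The exact classes available depend on the
# installed optimum-intel/transformers versions; treat this as an assumption,
# not the confirmed API of this service.
def _example_load_openvino_model(model_path):
    from optimum.intel.openvino import OVModelForCausalLM
    from transformers import AutoTokenizer

    example_tokenizer = AutoTokenizer.from_pretrained(model_path)
    # A vision-language checkpoint may instead require a multimodal model class
    # and an AutoProcessor; this mirrors the commented hints in load_models.
    example_model = OVModelForCausalLM.from_pretrained(model_path)
    return example_model, example_tokenizer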
# Simulate VLM processing
def process_image_query(image_data, query, task="vqa"):
    """
    Process an image with a query using the VLM model.

    Args:
        image_data: Base64 encoded image
        query: Text query to process
        task: The task to perform (vqa or search)

    Returns:
        Dictionary with the processing results
    """
    try:
        # Decode the image
        image_bytes = base64.b64decode(image_data)
        image = Image.open(BytesIO(image_bytes))

        # Log the request
        logger.info(f"Processing {task} request with query: {query}")
        logger.info(f"Image size: {image.size}")

        # Generate a simulated response
        if task == "vqa":
            # Simulate a Visual Question Answering response
            response = {
                "answer": f"This is a simulated VLM response for: '{query}'. In a real implementation, this would be generated by a Vision Language Model based on the image content."
            }
        else:
            # Simulate a Visual Search response
            response = {
                "answer": f"Search results for: '{query}'. Found 5 similar images (simulated results)."
            }

        # Add a slight delay to simulate processing time
        time.sleep(1.5)

        return response
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        return {"error": str(e)}
# API endpoints
@app.route('/vqa', methods=['POST'])
def vqa_endpoint():
    try:
        # Get data from request
        data = request.json
        if not data:
            return jsonify({"error": "No data provided"}), 400

        # Extract image and query
        image_data = data.get('image')
        query = data.get('query')

        if not image_data or not query:
            return jsonify({"error": "Image and query are required"}), 400

        # Process the request
        result = process_image_query(image_data, query, task="vqa")

        return jsonify(result)
    except Exception as e:
        logger.error(f"Error in VQA endpoint: {str(e)}")
        return jsonify({"error": str(e)}), 500

@app.route('/search', methods=['POST'])
def search_endpoint():
    try:
        # Get data from request
        data = request.json
        if not data:
            return jsonify({"error": "No data provided"}), 400

        # Extract image and query
        image_data = data.get('image')
        query = data.get('query')

        if not image_data or not query:
            return jsonify({"error": "Image and query are required"}), 400

        # Process the request
        result = process_image_query(image_data, query, task="search")

        return jsonify(result)
    except Exception as e:
        logger.error(f"Error in search endpoint: {str(e)}")
        return jsonify({"error": str(e)}), 500

@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({"status": "ok", "model": VLM_MODEL_NAME})

if __name__ == '__main__':
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Vision Language Model Backend')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to run the server on')
    parser.add_argument('--port', type=int, default=8399, help='Port to run the server on')
    parser.add_argument('--debug', action='store_true', help='Run in debug mode')
    args = parser.parse_args()

    # Load models
    load_models()

    # Run the Flask application
    app.run(host=args.host, port=args.port, debug=args.debug)
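
# Example client call for the endpoints above (reference only; assumes the
# server is running on the default host/port, and "example.jpg" is a
# placeholder filename). The "image" and "query" fields match what the
# /vqa handler expects:
#
#   import base64, requests
#   with open("example.jpg", "rb") as f:
#       payload = {
#           "image": base64.b64encode(f.read()).decode("utf-8"),
#           "query": "What objects are visible in this image?",
#       }
#   resp = requests.post("http://localhost:8399/vqa", json=payload, timeout=60)
#   print(resp.json())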