File size: 7,472 Bytes
58ad1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
847c22c
58ad1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
d6caf9e
58ad1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e950cec
58ad1a9
 
 
 
 
 
d6caf9e
58ad1a9
 
 
d6caf9e
58ad1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e950cec
58ad1a9
 
ebae902
 
58ad1a9
ebae902
 
 
58ad1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import base64
import requests
from typing import List, Tuple, Optional
from PIL import Image
from io import BytesIO

try:
    from google import genai
except ImportError:
    genai = None

try:
    from openai import OpenAI
    import gradio as gr
except ImportError:
    OpenAI = None
    gr = None # Keep gr dependency localized for error handling

# Import internal utility (if needed, though base64 encoding logic is placed inline)
# from utils import image_to_base664 


# --- API Clients Initialization ---
# Module-level singletons; each stays None when its SDK is not importable or
# when client construction fails (e.g. missing API key). Callers must check
# for None before use.
GEMINI_CLIENT = None
OPENAI_CLIENT = None

if genai:
    try:
        # Client initialization relies on GEMINI_API_KEY environment variable
        GEMINI_CLIENT = genai.Client()
    except Exception:
        # Fail silently if key is missing, handle error in function call
        pass

if OpenAI:
    try:
        # Client initialization relies on OPENAI_API_KEY environment variable
        OPENAI_CLIENT = OpenAI()
    except Exception:
        # Fail silently if key is missing, handle error in function call
        pass

def get_generation_prompt(
    model_choice: str,
    prompt: str,
    image_paths: List[str]
) -> str:
    """Analyze the input images and text prompt with the selected multimodal
    model and synthesize a single detailed prompt for the image generator.

    Args:
        model_choice: Analysis backend — 'gemini-2' or 'gpt image-1'.
        prompt: The user's guiding text prompt.
        image_paths: Filesystem paths of the uploaded reference images.

    Returns:
        The model-expanded descriptive prompt; on a missing client, an API
        error, or an unrecognized ``model_choice``, a fallback prompt string
        (this function never raises for those cases).
    """

    print(f"--- Analyzing inputs using {model_choice} ---")

    analysis_prompt = (
        f"You are an expert creative director. Based on the {len(image_paths)} input images and their text prompt, "
        f"synthesize a new, single, extremely detailed, aesthetic, and descriptive prompt (max 500 characters) "
        f"suitable for a cutting-edge text-to-image generator like DALL-E 3. "
        f"The resulting image must be a 'remix' or fusion incorporating key visual, thematic, and stylistic elements "
        f"from all available images, guided by the text prompt: '{prompt}'."
        f"Focus on composition, lighting, style, mood, and texture. Do not mention 'image', 'remix', or 'input images' in the output. "
        f"Output ONLY the final descriptive prompt text, nothing else."
    )

    # Load images as PIL objects; skip falsy entries (None / empty string).
    images = [Image.open(path) for path in image_paths if path]

    # --- GEMINI Analysis Path ---
    if model_choice == 'gemini-2':
        if not GEMINI_CLIENT:
            return f"Gemini API Key missing or client failed to initialize. Fallback prompt: Fusion of provided visual elements inspired by the prompt: {prompt}."
        try:
            # Contents should be images first, then the text prompt
            contents = images + [analysis_prompt]
            response = GEMINI_CLIENT.models.generate_content(
                model='gemini-2.0-flash-live',
                contents=contents
            )
            expanded_prompt = response.text.strip()
            print(f"Gemini Analysis Output: {expanded_prompt}")
            return expanded_prompt
        except Exception as e:
            print(f"Gemini API Error: {e}")
            return f"Error using Gemini for analysis. Fallback prompt: Creative fusion of the three elements provided, inspired by the theme: {prompt}."

    # --- GPT Analysis Path ---
    elif model_choice == 'gpt image-1':
        if not OPENAI_CLIENT:
            return f"OpenAI API Key missing or client failed to initialize. Fallback prompt: Fusion of provided visual elements inspired by the prompt: {prompt}."
        try:
            # Prepare the multimodal message parts: the instruction text first,
            # followed by each image as a base64 data URL.
            contents = [
                {"type": "text", "text": analysis_prompt}
            ]
            for img in images:
                buffered = BytesIO()
                # JPEG keeps the payload small, but Pillow cannot encode
                # alpha/palette modes (e.g. RGBA or P PNG uploads) as JPEG —
                # convert to RGB first to avoid an OSError here.
                img.convert("RGB").save(buffered, format="JPEG")
                img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
                contents.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{img_base64}",
                        "detail": "low" # Use low detail for speed
                    }
                })

            response = OPENAI_CLIENT.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "user", "content": contents}
                ],
                max_tokens=500
            )
            expanded_prompt = response.choices[0].message.content.strip()
            print(f"gpt-image-1-low Analysis Output: {expanded_prompt}")
            return expanded_prompt
        except Exception as e:
            print(f"GPT API Error: {e}")
            return f"Error using gpt-image-1-low for analysis. Fallback prompt: Creative fusion of the three elements provided, inspired by the theme: {prompt}."

    # Fallback if model is unrecognized
    return f"Creative synthesis of the visual elements provided, inspired by the prompt: {prompt}. Ensure photorealistic quality."

def generate_remixed_image(
    model_choice: str,
    prompt: str,
    image1_path: Optional[str],
    image2_path: Optional[str],
    image3_path: Optional[str]
) -> Tuple[str, Image.Image | None]:
    """Run the two-stage remix pipeline: expand the prompt with the chosen
    analysis model, then synthesize the final image with the OpenAI Images API.

    Returns a tuple of (final prompt text, generated PIL image). On a
    generation failure with Gradio available, returns a rendered error card
    instead of raising.
    """

    uploaded = [p for p in (image1_path, image2_path, image3_path) if p is not None]

    # Guard: the synthesis stage always requires the OpenAI client.
    if not OPENAI_CLIENT:
        if gr:
            raise gr.Error("OpenAI client not initialized. Please set OPENAI_API_KEY environment variable.")
        raise ValueError("OpenAI client not initialized.")

    # Guard: at least one reference image must be supplied.
    if not uploaded:
        if gr:
            raise gr.Error("Please upload at least one image to remix.")
        raise ValueError("No images provided.")

    # Stage 1: produce the optimized generation prompt via the analysis model.
    final_prompt = get_generation_prompt(model_choice, prompt, uploaded)
    print(f"\n--- Final Prompt for DALL-E 3: {final_prompt} ---")

    # Stage 2: synthesize the image.
    try:
        result = OPENAI_CLIENT.images.generate(
            model="gpt-image-1",
            prompt=final_prompt,
            size="1024x1024",
            quality="medium",  # valid: low | medium | high | auto
            n=1,
        )
        raw_bytes = base64.b64decode(result.data[0].b64_json)
        return final_prompt, Image.open(BytesIO(raw_bytes)).convert("RGB")

    except Exception as e:
        print(f"DALL-E 3 Generation Error: {e}")
        error_msg = f"Image generation failed: {str(e)}"

        if gr:
            from PIL import ImageDraw, ImageFont

            # Render a simple error card so the UI still displays something.
            canvas = Image.new('RGB', (1024, 1024), color = 'darkred')
            painter = ImageDraw.Draw(canvas)
            try:
                font = ImageFont.truetype("arial.ttf", 40)
            except IOError:
                font = ImageFont.load_default()

            painter.text((50, 450), "GENERATION FAILED", fill=(255, 255, 255), font=font)
            painter.text((50, 550), f"Error: {error_msg}", fill=(255, 200, 200), font=font)
            return f"FAILED. Error: {error_msg}", canvas

        raise ValueError(error_msg)