import os
import torch
from PIL import Image

# -------------------------
# 1) Secrets (from environment variables)
# -------------------------
# Make sure you set these in your environment or a .env file.
# Example on Linux/macOS:
#   export GROQ_API_KEY="your_key"
#   export HUGGINGFACEHUB_API_TOKEN="your_token"
os.environ["HF_TOKEN"] = os.environ.get("HUGGINGFACEHUB_API_TOKEN", "")

if not os.environ.get("GROQ_API_KEY"):
    raise ValueError("❌ Missing GROQ_API_KEY in environment variables")

if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
    print("⚠️ HUGGINGFACEHUB_API_TOKEN missing. If a model is gated, it may fail to download.")

# -------------------------
# 2) Device config
# -------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print("Device:", device)

# Allow TF32 matmuls for a speedup on Ampere+ GPUs (harmless no-op on CPU).
torch.backends.cuda.matmul.allow_tf32 = True

# -------------------------
# 3) LangChain (Groq)
# -------------------------
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0.7)

prompt_refiner = ChatPromptTemplate.from_template("""
You are an expert AI prompt engineer for SDXL text-to-image generation.
Convert the user's idea into a high-quality image prompt.

Rules:
- concise (max 60 words)
- include subject, environment, lighting, composition, style
- avoid brand names, watermarks, copyrighted characters
- keep any style the user mentions (anime/realistic/etc.)

User idea: {text}

Final image prompt:
""")

caption_refiner = ChatPromptTemplate.from_template("""
You are an expert image caption editor.
Rewrite the caption in clear, neutral English (1–2 sentences).
No identity guessing.

Raw caption: {caption}

Final caption:
""")

prompt_chain = prompt_refiner | llm | StrOutputParser()
caption_chain = caption_refiner | llm | StrOutputParser()

NEG_DEFAULT = "lowres, blurry, bad anatomy, extra fingers, watermark, text, logo, jpeg artifacts, deformed"

# -------------------------
# 4) SDXL pipeline
# -------------------------
from diffusers import StableDiffusionXLPipeline

MODEL_ID = "playgroundai/playground-v2.5-1024px-aesthetic"

pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    use_safetensors=True,
    token=os.environ.get("HUGGINGFACEHUB_API_TOKEN") or None,
).to(device)

# Memory savers for small GPUs (e.g., a 16 GB T4).
pipe.enable_attention_slicing()
try:
    pipe.enable_vae_tiling()
except Exception:
    pass

# Note: SDXL pipelines do not ship a safety_checker component, so the original
# `pipe.safety_checker = None` was a no-op and has been dropped.

def _gen(seed: int) -> torch.Generator:
    """Build a seeded generator on the right device for reproducible outputs."""
    seed = int(seed)
    if device == "cuda":
        return torch.Generator(device="cuda").manual_seed(seed)
    return torch.Generator().manual_seed(seed)

@torch.inference_mode()
def text_to_image(user_text, steps=30, guidance=6.5, seed=123, size=1024, negative_prompt=NEG_DEFAULT):
    if not user_text or not str(user_text).strip():
        raise ValueError("Please enter a non-empty prompt.")
    # Let the LLM rewrite the raw idea into a structured SDXL prompt first.
    enhanced = prompt_chain.invoke({"text": user_text}).strip()
    img = pipe(
        prompt=enhanced,
        negative_prompt=negative_prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        height=int(size),
        width=int(size),
        generator=_gen(seed),
    ).images[0]
    return enhanced, img

# -------------------------
# 5) Image → Text (BLIP)
# -------------------------
from transformers import pipeline as hf_pipeline

caption_model = hf_pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",
    device=0 if device == "cuda" else -1,
)

def image_to_text(img):
    if img is None:
        raise ValueError("Please upload an image.")
    raw = caption_model(img)[0]["generated_text"].strip()
    # Pass the raw BLIP caption through the LLM for a cleaner rewrite.
    refined = caption_chain.invoke({"caption": raw}).strip()
    return raw, refined

# -------------------------
# 6) Gradio App
# -------------------------
import gradio as gr

with gr.Blocks(title="LangChain Text ↔ Image (SDXL, Secure)") as app:
    gr.Markdown("## 🔁 LangChain Text ↔ Image (SDXL, Secret Key Based) — Better Quality on T4")

    with gr.Tab("Text → Image (SDXL)"):
        txt = gr.Textbox(
            label="Enter text prompt",
            placeholder="e.g., A futuristic hospital lab with AI robots, cinematic lighting, ultra-detailed",
        )
        with gr.Row():
            size = gr.Radio([512, 1024], value=1024, label="Resolution (use 512 if OOM)")
            seed = gr.Number(value=123, label="Seed")
        with gr.Row():
            steps = gr.Slider(10, 50, value=30, step=1, label="Steps (quality rises with steps)")
            guidance = gr.Slider(1.0, 10.0, value=6.5, step=0.1, label="Guidance (5–8 works best)")
        negative = gr.Textbox(value=NEG_DEFAULT, label="Negative prompt (quality control)")
        btn1 = gr.Button("Generate Image")
        refined_prompt = gr.Textbox(label="Enhanced Prompt (LangChain)", interactive=False)
        img = gr.Image(label="Generated Image")
        btn1.click(text_to_image, [txt, steps, guidance, seed, size, negative], [refined_prompt, img])

    with gr.Tab("Image → Text"):
        img_in = gr.Image(type="pil", label="Upload image")
        btn2 = gr.Button("Generate Caption")
        raw = gr.Textbox(label="Raw Caption (BLIP)", interactive=False)
        clean = gr.Textbox(label="Refined Caption (LangChain)", interactive=False)
        btn2.click(image_to_text, img_in, [raw, clean])

app.launch(share=True)
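
# -------------------------
# Optional: headless smoke test — a minimal sketch, not part of the original app.
# It only assumes the functions defined above; the prompt string and the
# "smoke_test.png" filename are illustrative. Since app.launch() blocks, comment
# it out (or run this in a separate cell) before trying these lines.
# -------------------------
# enhanced, image = text_to_image("a misty mountain village at dawn", steps=20, size=512, seed=42)
# print("Enhanced prompt:", enhanced)
# image.save("smoke_test.png")
# raw_caption, refined_caption = image_to_text(Image.open("smoke_test.png"))
# print("Raw caption:", raw_caption)
# print("Refined caption:", refined_caption)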