Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from PIL import Image | |
| import torch | |
| from RealESRGAN import RealESRGAN | |
| from io import BytesIO | |
| # Function to load the model based on scale and anime toggle | |
def load_model(scale, anime=False):
    """Build a RealESRGAN model and load the weights matching the request.

    Args:
        scale: Integer upscaling factor. Weights are listed below for 2, 4
            and 8.
        anime: When True, select the anime-specific weights. Only scale 4
            has an anime weight file.

    Returns:
        The model with its weights loaded, or ``None`` when the combination
        is unsupported or loading fails. In both failure cases the problem
        is reported to the page through ``st.error``.
    """
    # Weight file for every supported (scale, anime) combination.
    weights_by_variant = {
        (2, False): 'model/RealESRGAN_x2.pth',
        (4, False): 'model/RealESRGAN_x4plus.pth',
        (8, False): 'model/RealESRGAN_x8.pth',
        (4, True): 'model/RealESRGAN_x4plus_anime_6B.pth',
    }
    try:
        # Validate the request before building the network so that an
        # unsupported pair produces a readable message instead of a bare
        # KeyError such as "(2, True)".
        model_path = weights_by_variant.get((scale, anime))
        if model_path is None:
            supported = ", ".join(
                f"{s}x{' (anime)' if a else ''}" for s, a in weights_by_variant
            )
            st.error(
                "Failed to load the model: unsupported combination "
                f"scale={scale}, anime={anime}. Supported: {supported}"
            )
            return None

        # Use the GPU when torch reports one, otherwise fall back to CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): RealESRGAN comes from the import at the top of the
        # file; this call assumes its constructor accepts `anime` - confirm.
        model = RealESRGAN(device, scale=scale, anime=anime)
        model.load_weights(model_path)
        return model
    except Exception as e:
        st.error(f"Failed to load the model: {e}")
        return None
def enhance_image(image, scale, anime):
    """Upscale ``image`` and return it together with a PNG byte stream.

    Args:
        image: Input image; it must expose ``mode`` and ``convert`` (the
            caller in this file passes the result of ``Image.open``).
        scale: Upscaling factor forwarded to ``load_model``.
        anime: Anime flag forwarded to ``load_model``.

    Returns:
        ``(enhanced_image, buffer)`` where ``buffer`` is a ``BytesIO``
        holding the PNG-encoded result, rewound to position 0. Returns
        ``(None, None)`` when the model could not be loaded or any step
        raised; exceptions are reported through ``st.error``.
    """
    try:
        upscaler = load_model(scale, anime=anime)
        if upscaler is not None:
            # Any mode other than RGB (alpha, palette, greyscale, ...) is
            # converted before being handed to the model.
            rgb_input = image if image.mode == 'RGB' else image.convert('RGB')
            upscaled = upscaler.predict(rgb_input)

            png_stream = BytesIO()
            upscaled.save(png_stream, format="PNG")
            png_stream.seek(0)  # rewind so consumers read from the start
            return upscaled, png_stream
    except Exception as e:
        st.error(f"An error occurred during image enhancement: {e}")
    # Reached when the model failed to load or an exception was reported.
    return None, None
def main():
    """Render the Streamlit page: upload an image, pick options, restore it.

    The page offers a file uploader, an "Anime Image" checkbox (which fixes
    the scale at 4x), a scale selector otherwise, and a "Restore Image"
    button. After a successful restoration the original and enhanced images
    are shown side by side with a PNG download button.

    The result is kept in ``st.session_state`` because ``st.button`` is only
    True for the single rerun triggered by its click. Without persistence,
    clicking the download button (which reruns the script) would make the
    enhanced image and the download button disappear.
    """
    import hashlib  # local import: only needed to fingerprint the upload

    result_slot = "restored_result"

    st.title("Generative AI Image Restoration")

    uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
    if uploaded_image is None:
        return

    try:
        image = Image.open(uploaded_image)

        anime = st.checkbox("Anime Image", value=False)
        if anime:
            scale = "4x"  # Set to 4x automatically when anime is selected
        else:
            scale = st.radio("Upscaling Factor", ["2x", "4x", "8x"], index=0)
        scale_value = int(scale.replace('x', ''))

        # Identifies the exact request (file content + options) so that a
        # stored result is never displayed for a different upload or setting.
        request_key = (
            hashlib.sha256(uploaded_image.getvalue()).hexdigest(),
            scale_value,
            anime,
        )

        if st.button("Restore Image"):
            enhanced_image, buffer = enhance_image(image, scale_value, anime)
            # Explicit None check: do not rely on the truthiness of the
            # returned image object.
            if enhanced_image is not None:
                st.session_state[result_slot] = {
                    "request": request_key,
                    "image": enhanced_image,
                    # Store bytes rather than the stream so repeated reruns
                    # do not depend on the stream position.
                    "png": buffer.getvalue(),
                }
            elif result_slot in st.session_state:
                # Do not keep showing an older result next to a new error.
                del st.session_state[result_slot]

        stored = st.session_state.get(result_slot)
        if stored is not None and stored["request"] == request_key:
            # Show images side by side.
            # NOTE(review): newer Streamlit releases appear to deprecate
            # `use_column_width` - confirm against the pinned version
            # before changing it.
            col1, col2 = st.columns(2)
            with col1:
                st.image(image, caption="Original Image", use_column_width=True)
            with col2:
                st.image(stored["image"], caption="Enhanced Image", use_column_width=True)

            st.download_button(
                label="Download Enhanced Image",
                data=stored["png"],
                file_name="enhanced_image.png",
                mime="image/png",
            )
    except Exception as e:
        st.error(f"An error occurred while processing the image: {e}")
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()