thanhkt commited on
Commit
5567b73
·
1 Parent(s): 2daf76b

Upload 9 files

Browse files
llm_config/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_config/__init__.py
2
+ """
3
+ LLM Configuration system with SOLID architecture.
4
+ """
5
+
6
+ from .interfaces import (
7
+ APIKeyValidationResult,
8
+ LLMConfiguration,
9
+ IProviderValidator,
10
+ IConfigurationManager,
11
+ IProviderManager,
12
+ IUIStateManager,
13
+ INotificationService
14
+ )
15
+
16
+ from .validation import ProviderValidator
17
+ from .configuration import ConfigurationManager
18
+ from .provider_manager import EnhancedProviderManager
19
+ from .ui_manager import UIStateManager
20
+ from .notifications import GradioNotificationService
21
+ from .llm_config_facade import LLMConfigurationFacade
22
+
23
+ __all__ = [
24
+ 'APIKeyValidationResult',
25
+ 'LLMConfiguration',
26
+ 'IProviderValidator',
27
+ 'IConfigurationManager',
28
+ 'IProviderManager',
29
+ 'IUIStateManager',
30
+ 'INotificationService',
31
+ 'ProviderValidator',
32
+ 'ConfigurationManager',
33
+ 'EnhancedProviderManager',
34
+ 'UIStateManager',
35
+ 'GradioNotificationService',
36
+ 'LLMConfigurationFacade'
37
+ ]
llm_config/configuration.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_config/configuration.py
2
+ """
3
+ Configuration management with persistence and validation.
4
+ """
5
+
6
+ import json
7
+ import os
8
+ import logging
9
+ from typing import Optional
10
+ from .interfaces import IConfigurationManager, LLMConfiguration
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class ConfigurationManager(IConfigurationManager):
16
+ """Manages LLM configuration persistence and defaults."""
17
+
18
+ def __init__(self, config_file: str = "llm_config.json"):
19
+ self.config_file = config_file
20
+ self.config_dir = os.path.dirname(os.path.abspath(config_file))
21
+
22
+ # Ensure config directory exists
23
+ if self.config_dir and not os.path.exists(self.config_dir):
24
+ os.makedirs(self.config_dir, exist_ok=True)
25
+
26
+ def save_configuration(self, config: LLMConfiguration) -> bool:
27
+ """Save configuration to persistent storage."""
28
+ try:
29
+ config_dict = {
30
+ 'provider': config.provider,
31
+ 'model': config.model,
32
+ 'api_key': config.api_key,
33
+ 'temperature': config.temperature,
34
+ 'max_retries': config.max_retries,
35
+ 'helper_model': config.helper_model,
36
+ 'use_rag': config.use_rag,
37
+ 'use_visual_fix_code': config.use_visual_fix_code,
38
+ 'use_context_learning': config.use_context_learning,
39
+ 'verbose': config.verbose,
40
+ 'max_scene_concurrency': config.max_scene_concurrency
41
+ }
42
+
43
+ with open(self.config_file, 'w') as f:
44
+ json.dump(config_dict, f, indent=2)
45
+
46
+ logger.info(f"Configuration saved to {self.config_file}")
47
+ return True
48
+
49
+ except Exception as e:
50
+ logger.error(f"Failed to save configuration: {str(e)}")
51
+ return False
52
+
53
+ def load_configuration(self) -> Optional[LLMConfiguration]:
54
+ """Load configuration from persistent storage."""
55
+ try:
56
+ if not os.path.exists(self.config_file):
57
+ logger.info(f"Configuration file {self.config_file} not found")
58
+ return None
59
+
60
+ with open(self.config_file, 'r') as f:
61
+ config_dict = json.load(f)
62
+
63
+ config = LLMConfiguration(
64
+ provider=config_dict.get('provider', 'OpenAI'),
65
+ model=config_dict.get('model', 'gpt-4'),
66
+ api_key=config_dict.get('api_key', ''),
67
+ temperature=config_dict.get('temperature', 0.7),
68
+ max_retries=config_dict.get('max_retries', 3),
69
+ helper_model=config_dict.get('helper_model'),
70
+ use_rag=config_dict.get('use_rag', True),
71
+ use_visual_fix_code=config_dict.get('use_visual_fix_code', False),
72
+ use_context_learning=config_dict.get('use_context_learning', True),
73
+ verbose=config_dict.get('verbose', False),
74
+ max_scene_concurrency=config_dict.get('max_scene_concurrency', 1)
75
+ )
76
+
77
+ logger.info(f"Configuration loaded from {self.config_file}")
78
+ return config
79
+
80
+ except Exception as e:
81
+ logger.error(f"Failed to load configuration: {str(e)}")
82
+ return None
83
+
84
+ def get_default_configuration(self) -> LLMConfiguration:
85
+ """Get default configuration."""
86
+ return LLMConfiguration(
87
+ provider='OpenAI',
88
+ model='gpt-4',
89
+ api_key='',
90
+ temperature=0.7,
91
+ max_retries=3,
92
+ helper_model='gemini/gemini-2.5-flash-preview-04-17',
93
+ use_rag=True,
94
+ use_visual_fix_code=False,
95
+ use_context_learning=True,
96
+ verbose=False,
97
+ max_scene_concurrency=1
98
+ )
99
+
100
+ def backup_configuration(self) -> bool:
101
+ """Create a backup of the current configuration."""
102
+ try:
103
+ if not os.path.exists(self.config_file):
104
+ return True # Nothing to backup
105
+
106
+ backup_file = f"{self.config_file}.backup"
107
+
108
+ with open(self.config_file, 'r') as src:
109
+ with open(backup_file, 'w') as dst:
110
+ dst.write(src.read())
111
+
112
+ logger.info(f"Configuration backed up to {backup_file}")
113
+ return True
114
+
115
+ except Exception as e:
116
+ logger.error(f"Failed to backup configuration: {str(e)}")
117
+ return False
118
+
119
+ def restore_configuration(self) -> Optional[LLMConfiguration]:
120
+ """Restore configuration from backup."""
121
+ try:
122
+ backup_file = f"{self.config_file}.backup"
123
+
124
+ if not os.path.exists(backup_file):
125
+ logger.warning("No backup file found")
126
+ return None
127
+
128
+ # Replace current config with backup
129
+ with open(backup_file, 'r') as src:
130
+ with open(self.config_file, 'w') as dst:
131
+ dst.write(src.read())
132
+
133
+ logger.info("Configuration restored from backup")
134
+ return self.load_configuration()
135
+
136
+ except Exception as e:
137
+ logger.error(f"Failed to restore configuration: {str(e)}")
138
+ return None
llm_config/interfaces.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_config/interfaces.py
2
+ """
3
+ Abstract interfaces for LLM configuration system following SOLID principles.
4
+ """
5
+
6
+ from abc import ABC, abstractmethod
7
+ from typing import Dict, List, Optional, Tuple
8
+ from dataclasses import dataclass
9
+
10
+
11
+ @dataclass
12
+ class APIKeyValidationResult:
13
+ """Result of API key validation."""
14
+ is_valid: bool
15
+ error_message: Optional[str] = None
16
+ provider_name: Optional[str] = None
17
+
18
+
19
+ @dataclass
20
+ class LLMConfiguration:
21
+ """Complete LLM configuration."""
22
+ provider: str
23
+ model: str
24
+ api_key: str
25
+ temperature: float = 0.7
26
+ max_retries: int = 3
27
+ helper_model: Optional[str] = None
28
+ use_rag: bool = True
29
+ use_visual_fix_code: bool = False
30
+ use_context_learning: bool = True
31
+ verbose: bool = False
32
+ max_scene_concurrency: int = 1
33
+
34
+
35
+ class IProviderValidator(ABC):
36
+ """Interface for validating provider configurations."""
37
+
38
+ @abstractmethod
39
+ async def validate_api_key(self, provider: str, api_key: str) -> APIKeyValidationResult:
40
+ """Validate an API key for a specific provider."""
41
+ pass
42
+
43
+ @abstractmethod
44
+ def get_supported_providers(self) -> List[str]:
45
+ """Get list of supported providers."""
46
+ pass
47
+
48
+
49
+ class IConfigurationManager(ABC):
50
+ """Interface for managing LLM configurations."""
51
+
52
+ @abstractmethod
53
+ def save_configuration(self, config: LLMConfiguration) -> bool:
54
+ """Save configuration to persistent storage."""
55
+ pass
56
+
57
+ @abstractmethod
58
+ def load_configuration(self) -> Optional[LLMConfiguration]:
59
+ """Load configuration from persistent storage."""
60
+ pass
61
+
62
+ @abstractmethod
63
+ def get_default_configuration(self) -> LLMConfiguration:
64
+ """Get default configuration."""
65
+ pass
66
+
67
+
68
+ class IProviderManager(ABC):
69
+ """Interface for managing providers and models."""
70
+
71
+ @abstractmethod
72
+ def get_providers(self) -> List[str]:
73
+ """Get available providers."""
74
+ pass
75
+
76
+ @abstractmethod
77
+ def get_models(self, provider: str) -> List[str]:
78
+ """Get available models for a provider."""
79
+ pass
80
+
81
+ @abstractmethod
82
+ def get_model_description(self, model: str) -> str:
83
+ """Get description for a model."""
84
+ pass
85
+
86
+ @abstractmethod
87
+ def set_api_key(self, provider: str, api_key: str) -> None:
88
+ """Set API key for a provider."""
89
+ pass
90
+
91
+
92
+ class IUIStateManager(ABC):
93
+ """Interface for managing UI state."""
94
+
95
+ @abstractmethod
96
+ def update_provider_selection(self, provider: str) -> Dict:
97
+ """Update UI when provider is selected."""
98
+ pass
99
+
100
+ @abstractmethod
101
+ def update_model_selection(self, model: str) -> Dict:
102
+ """Update UI when model is selected."""
103
+ pass
104
+
105
+ @abstractmethod
106
+ def show_validation_feedback(self, result: APIKeyValidationResult) -> Dict:
107
+ """Show validation feedback to user."""
108
+ pass
109
+
110
+ @abstractmethod
111
+ def reset_form(self) -> Dict:
112
+ """Reset form to default state."""
113
+ pass
114
+
115
+
116
+ class INotificationService(ABC):
117
+ """Interface for user notifications."""
118
+
119
+ @abstractmethod
120
+ def show_success(self, message: str) -> None:
121
+ """Show success notification."""
122
+ pass
123
+
124
+ @abstractmethod
125
+ def show_error(self, message: str) -> None:
126
+ """Show error notification."""
127
+ pass
128
+
129
+ @abstractmethod
130
+ def show_warning(self, message: str) -> None:
131
+ """Show warning notification."""
132
+ pass
llm_config/llm_config_facade.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_config/llm_config_facade.py
2
+ """
3
+ Facade pattern implementation for clean LLM configuration management.
4
+ This provides a simple interface to the complex subsystem.
5
+ """
6
+
7
+ import logging
8
+ import asyncio
9
+ from typing import Dict, List, Optional, Tuple, Any
10
+ from .interfaces import (
11
+ LLMConfiguration, APIKeyValidationResult,
12
+ IProviderValidator, IConfigurationManager, IProviderManager,
13
+ IUIStateManager, INotificationService
14
+ )
15
+ from .validation import ProviderValidator
16
+ from .configuration import ConfigurationManager
17
+ from .provider_manager import EnhancedProviderManager
18
+ from .ui_manager import UIStateManager
19
+ from .notifications import GradioNotificationService
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ class LLMConfigurationFacade:
25
+ """
26
+ Facade for LLM configuration system.
27
+ Provides a simple interface to manage providers, models, and API keys.
28
+ """
29
+
30
+ def __init__(self, config_file: str = "llm_config.json"):
31
+ # Initialize all components
32
+ self.validator: IProviderValidator = ProviderValidator()
33
+ self.config_manager: IConfigurationManager = ConfigurationManager(config_file)
34
+ self.provider_manager: IProviderManager = EnhancedProviderManager()
35
+ self.notification_service: INotificationService = GradioNotificationService()
36
+ self.ui_manager: IUIStateManager = UIStateManager(
37
+ self.provider_manager,
38
+ self.notification_service
39
+ )
40
+
41
+ # Load saved configuration if available
42
+ self._load_saved_configuration()
43
+
44
+ def _load_saved_configuration(self) -> None:
45
+ """Load previously saved configuration."""
46
+ try:
47
+ config = self.config_manager.load_configuration()
48
+ if config:
49
+ # Set API key in provider manager
50
+ if config.api_key:
51
+ self.provider_manager.set_api_key(config.provider, config.api_key)
52
+
53
+ logger.info("Loaded saved configuration")
54
+ else:
55
+ logger.info("No saved configuration found, using defaults")
56
+ except Exception as e:
57
+ logger.error(f"Failed to load saved configuration: {str(e)}")
58
+
59
+ # Provider Management
60
+ def get_providers(self) -> List[str]:
61
+ """Get list of available providers."""
62
+ return self.provider_manager.get_providers()
63
+
64
+ def get_models(self, provider: str) -> List[str]:
65
+ """Get models for a specific provider."""
66
+ return self.provider_manager.get_models(provider)
67
+
68
+ def get_model_description(self, model: str) -> str:
69
+ """Get description for a model."""
70
+ return self.provider_manager.get_model_description(model)
71
+
72
+ def get_provider_description(self, provider: str) -> str:
73
+ """Get description for a provider."""
74
+ return self.provider_manager.get_provider_description(provider)
75
+
76
+ # API Key Management
77
+ def set_api_key(self, provider: str, api_key: str) -> bool:
78
+ """Set API key for a provider."""
79
+ try:
80
+ if not api_key or not api_key.strip():
81
+ self.notification_service.show_error("API key cannot be empty")
82
+ return False
83
+
84
+ self.provider_manager.set_api_key(provider, api_key)
85
+ self.notification_service.show_success(f"API key set for {provider}")
86
+ return True
87
+ except Exception as e:
88
+ logger.error(f"Failed to set API key: {str(e)}")
89
+ self.notification_service.show_error(f"Failed to set API key: {str(e)}")
90
+ return False
91
+
92
+ def validate_api_key(self, provider: str, api_key: str) -> APIKeyValidationResult:
93
+ """Validate an API key synchronously."""
94
+ return self.validator.validate_api_key_sync(provider, api_key)
95
+
96
+ async def validate_api_key_async(self, provider: str, api_key: str) -> APIKeyValidationResult:
97
+ """Validate an API key asynchronously."""
98
+ return await self.validator.validate_api_key(provider, api_key)
99
+
100
+ def has_api_key(self, provider: str) -> bool:
101
+ """Check if provider has an API key set."""
102
+ return self.provider_manager.has_api_key(provider)
103
+
104
+ # Configuration Management
105
+ def save_configuration(self, config: LLMConfiguration) -> bool:
106
+ """Save complete configuration."""
107
+ try:
108
+ success = self.config_manager.save_configuration(config)
109
+ if success:
110
+ self.notification_service.show_success("Configuration saved successfully")
111
+ else:
112
+ self.notification_service.show_error("Failed to save configuration")
113
+ return success
114
+ except Exception as e:
115
+ logger.error(f"Failed to save configuration: {str(e)}")
116
+ self.notification_service.show_error(f"Failed to save configuration: {str(e)}")
117
+ return False
118
+
119
+ def load_configuration(self) -> Optional[LLMConfiguration]:
120
+ """Load saved configuration."""
121
+ return self.config_manager.load_configuration()
122
+
123
+ def get_default_configuration(self) -> LLMConfiguration:
124
+ """Get default configuration."""
125
+ return self.config_manager.get_default_configuration()
126
+
127
+ def create_configuration(self, provider: str, model: str, api_key: str, **kwargs) -> LLMConfiguration:
128
+ """Create a new LLM configuration."""
129
+ return LLMConfiguration(
130
+ provider=provider,
131
+ model=model,
132
+ api_key=api_key,
133
+ temperature=kwargs.get('temperature', 0.7),
134
+ max_retries=kwargs.get('max_retries', 3),
135
+ helper_model=kwargs.get('helper_model'),
136
+ use_rag=kwargs.get('use_rag', True),
137
+ use_visual_fix_code=kwargs.get('use_visual_fix_code', False),
138
+ use_context_learning=kwargs.get('use_context_learning', True),
139
+ verbose=kwargs.get('verbose', False),
140
+ max_scene_concurrency=kwargs.get('max_scene_concurrency', 1)
141
+ )
142
+
143
+ # UI State Management
144
+ def update_provider_selection(self, provider: str) -> Dict:
145
+ """Update UI when provider is selected."""
146
+ return self.ui_manager.update_provider_selection(provider)
147
+
148
+ def update_model_selection(self, model: str) -> Dict:
149
+ """Update UI when model is selected."""
150
+ return self.ui_manager.update_model_selection(model)
151
+
152
+ def show_validation_feedback(self, result: APIKeyValidationResult) -> Dict:
153
+ """Show validation feedback in UI."""
154
+ return self.ui_manager.show_validation_feedback(result)
155
+
156
+ def reset_form(self) -> Dict:
157
+ """Reset form to default state."""
158
+ return self.ui_manager.reset_form()
159
+
160
+ def get_current_ui_state(self) -> Dict[str, Any]:
161
+ """Get current UI state."""
162
+ return self.ui_manager.get_current_configuration()
163
+
164
+ def validate_current_configuration(self) -> Tuple[bool, str]:
165
+ """Validate current configuration."""
166
+ return self.ui_manager.validate_current_configuration()
167
+
168
+ # Utility Methods
169
+ def get_configuration_summary(self) -> Dict:
170
+ """Get summary of current configuration."""
171
+ return self.provider_manager.get_configuration_summary()
172
+
173
+ def test_configuration(self, config: LLMConfiguration) -> Tuple[bool, str]:
174
+ """Test a configuration by validating API key."""
175
+ try:
176
+ result = self.validate_api_key(config.provider, config.api_key)
177
+ return result.is_valid, result.error_message or "Configuration test completed"
178
+ except Exception as e:
179
+ logger.error(f"Failed to test configuration: {str(e)}")
180
+ return False, f"Test failed: {str(e)}"
181
+
182
+ def backup_configuration(self) -> bool:
183
+ """Create a backup of current configuration."""
184
+ return self.config_manager.backup_configuration()
185
+
186
+ def restore_configuration(self) -> Optional[LLMConfiguration]:
187
+ """Restore configuration from backup."""
188
+ config = self.config_manager.restore_configuration()
189
+ if config:
190
+ self.notification_service.show_success("Configuration restored from backup")
191
+ else:
192
+ self.notification_service.show_error("Failed to restore configuration")
193
+ return config
194
+
195
+ def clear_all_api_keys(self) -> None:
196
+ """Clear all stored API keys."""
197
+ try:
198
+ for provider in self.get_providers():
199
+ self.provider_manager.clear_api_key(provider)
200
+ self.notification_service.show_success("All API keys cleared")
201
+ except Exception as e:
202
+ logger.error(f"Failed to clear API keys: {str(e)}")
203
+ self.notification_service.show_error(f"Failed to clear API keys: {str(e)}")
204
+
205
+ def get_last_notification(self) -> Optional[str]:
206
+ """Get last notification message."""
207
+ return self.notification_service.get_last_notification()
208
+
209
+ def initialize_ui_defaults(self) -> Dict:
210
+ """Initialize UI with default values."""
211
+ try:
212
+ providers = self.get_providers()
213
+ if not providers:
214
+ return {}
215
+
216
+ default_provider = providers[0]
217
+ models = self.get_models(default_provider)
218
+ default_model = models[0] if models else None
219
+
220
+ return {
221
+ 'provider_choices': providers,
222
+ 'provider_value': default_provider,
223
+ 'model_choices': models,
224
+ 'model_value': default_model,
225
+ 'model_description': self.get_model_description(default_model) if default_model else "",
226
+ 'has_api_key': self.has_api_key(default_provider)
227
+ }
228
+ except Exception as e:
229
+ logger.error(f"Failed to initialize UI defaults: {str(e)}")
230
+ return {}
llm_config/notifications.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_config/notifications.py
2
+ """
3
+ Notification service for user feedback.
4
+ """
5
+
6
+ import gradio as gr
7
+ import logging
8
+ from typing import Optional
9
+ from .interfaces import INotificationService
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class GradioNotificationService(INotificationService):
15
+ """Gradio-specific notification service for user feedback."""
16
+
17
+ def __init__(self):
18
+ self.last_message: Optional[str] = None
19
+ self.last_type: Optional[str] = None
20
+
21
+ def show_success(self, message: str) -> None:
22
+ """Show success notification."""
23
+ self.last_message = f"✅ {message}"
24
+ self.last_type = "success"
25
+ logger.info(f"Success notification: {message}")
26
+ # In a real Gradio app, you might want to trigger a UI update here
27
+
28
+ def show_error(self, message: str) -> None:
29
+ """Show error notification."""
30
+ self.last_message = f"❌ {message}"
31
+ self.last_type = "error"
32
+ logger.error(f"Error notification: {message}")
33
+
34
+ def show_warning(self, message: str) -> None:
35
+ """Show warning notification."""
36
+ self.last_message = f"⚠️ {message}"
37
+ self.last_type = "warning"
38
+ logger.warning(f"Warning notification: {message}")
39
+
40
+ def show_info(self, message: str) -> None:
41
+ """Show info notification."""
42
+ self.last_message = f"ℹ️ {message}"
43
+ self.last_type = "info"
44
+ logger.info(f"Info notification: {message}")
45
+
46
+ def get_last_notification(self) -> Optional[str]:
47
+ """Get the last notification message."""
48
+ return self.last_message
49
+
50
+ def get_last_type(self) -> Optional[str]:
51
+ """Get the type of the last notification."""
52
+ return self.last_type
53
+
54
+ def clear_notifications(self) -> None:
55
+ """Clear all notifications."""
56
+ self.last_message = None
57
+ self.last_type = None
58
+
59
+ def create_gradio_update(self, message: str, notification_type: str = "info") -> dict:
60
+ """Create a Gradio update dict for displaying notifications."""
61
+ icons = {
62
+ "success": "✅",
63
+ "error": "❌",
64
+ "warning": "⚠️",
65
+ "info": "ℹ️"
66
+ }
67
+
68
+ icon = icons.get(notification_type, "ℹ️")
69
+ formatted_message = f"{icon} {message}"
70
+
71
+ # Store the notification
72
+ self.last_message = formatted_message
73
+ self.last_type = notification_type
74
+
75
+ return gr.update(
76
+ value=formatted_message,
77
+ visible=True
78
+ )
79
+
80
+ def format_validation_message(self, is_valid: bool, message: str) -> str:
81
+ """Format a validation message with appropriate styling."""
82
+ if is_valid:
83
+ return f"✅ {message}"
84
+ else:
85
+ return f"❌ {message}"
86
+
87
+ def format_status_message(self, status: str, details: str = "") -> str:
88
+ """Format a status message."""
89
+ if details:
90
+ return f"📊 {status}: {details}"
91
+ else:
92
+ return f"📊 {status}"
93
+
94
+
95
+ class ConsoleNotificationService(INotificationService):
96
+ """Console-based notification service for testing/development."""
97
+
98
+ def show_success(self, message: str) -> None:
99
+ """Show success notification."""
100
+ print(f"✅ SUCCESS: {message}")
101
+ logger.info(f"Success: {message}")
102
+
103
+ def show_error(self, message: str) -> None:
104
+ """Show error notification."""
105
+ print(f"❌ ERROR: {message}")
106
+ logger.error(f"Error: {message}")
107
+
108
+ def show_warning(self, message: str) -> None:
109
+ """Show warning notification."""
110
+ print(f"⚠️ WARNING: {message}")
111
+ logger.warning(f"Warning: {message}")
112
+
113
+ def show_info(self, message: str) -> None:
114
+ """Show info notification."""
115
+ print(f"ℹ️ INFO: {message}")
116
+ logger.info(f"Info: {message}")
llm_config/provider_manager.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_config/provider_manager.py
2
+ """
3
+ Enhanced provider manager with better separation of concerns.
4
+ """
5
+
6
+ import os
7
+ import logging
8
+ from typing import Dict, List, Optional
9
+ from .interfaces import IProviderManager
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class EnhancedProviderManager(IProviderManager):
15
+ """Enhanced provider manager with better organization and extensibility."""
16
+
17
+ def __init__(self):
18
+ self.providers_config = {
19
+ 'OpenAI': {
20
+ 'api_key_env': 'OPENAI_API_KEY',
21
+ 'models': [
22
+ 'openai/gpt-4',
23
+ 'openai/gpt-4o',
24
+ 'openai/gpt-3.5-turbo'
25
+ ],
26
+ 'display_name': 'OpenAI',
27
+ 'description': 'Advanced AI models from OpenAI'
28
+ },
29
+ 'Google Gemini': {
30
+ 'api_key_env': 'GOOGLE_API_KEY',
31
+ 'models': [
32
+ 'gemini/gemini-1.5-pro-002',
33
+ 'gemini/gemini-2.5-flash-preview-04-17'
34
+ ],
35
+ 'display_name': 'Google Gemini',
36
+ 'description': 'Google\'s powerful Gemini models'
37
+ },
38
+ 'Anthropic': {
39
+ 'api_key_env': 'ANTHROPIC_API_KEY',
40
+ 'models': [
41
+ 'anthropic/claude-3-5-sonnet-20241022',
42
+ 'anthropic/claude-3-haiku'
43
+ ],
44
+ 'display_name': 'Anthropic Claude',
45
+ 'description': 'Anthropic\'s Claude family of models'
46
+ },
47
+ 'OpenRouter': {
48
+ 'api_key_env': 'OPENROUTER_API_KEY',
49
+ 'models': [
50
+ 'openrouter/openai/gpt-4o',
51
+ 'openrouter/openai/gpt-4o-mini',
52
+ 'openrouter/anthropic/claude-3.5-sonnet',
53
+ 'openrouter/anthropic/claude-3-haiku',
54
+ 'openrouter/google/gemini-pro-1.5',
55
+ 'openrouter/deepseek/deepseek-chat',
56
+ 'openrouter/qwen/qwen-2.5-72b-instruct',
57
+ 'openrouter/meta-llama/llama-3.1-8b-instruct:free',
58
+ 'openrouter/microsoft/phi-3-mini-128k-instruct:free'
59
+ ],
60
+ 'display_name': 'OpenRouter',
61
+ 'description': 'Access multiple models through OpenRouter'
62
+ }
63
+ }
64
+
65
+ self.model_descriptions = {
66
+ "openai/gpt-4": "🎯 Reliable and consistent, great for educational content",
67
+ "openai/gpt-4o": "🚀 Latest OpenAI model with enhanced capabilities",
68
+ "gemini/gemini-1.5-pro-002": "🧠 Advanced reasoning, excellent for complex mathematical concepts",
69
+ "gemini/gemini-2.5-flash-preview-04-17": "⚡ Fast processing, good for quick prototypes",
70
+ "anthropic/claude-3-5-sonnet-20241022": "📚 Excellent at detailed explanations and structured content",
71
+ "anthropic/claude-3-haiku": "💨 Fast and efficient for simpler tasks",
72
+ "openrouter/openai/gpt-4o": "🌐 GPT-4o via OpenRouter - Powerful and versatile",
73
+ "openrouter/openai/gpt-4o-mini": "🌐 GPT-4o Mini via OpenRouter - Fast and cost-effective",
74
+ "openrouter/anthropic/claude-3.5-sonnet": "🌐 Claude 3.5 Sonnet via OpenRouter - Excellent reasoning",
75
+ "openrouter/anthropic/claude-3-haiku": "🌐 Claude 3 Haiku via OpenRouter - Quick responses",
76
+ "openrouter/google/gemini-pro-1.5": "🌐 Gemini Pro 1.5 via OpenRouter - Google's advanced model",
77
+ "openrouter/deepseek/deepseek-chat": "🌐 DeepSeek Chat via OpenRouter - Advanced conversation",
78
+ "openrouter/qwen/qwen-2.5-72b-instruct": "🌐 Qwen 2.5 72B via OpenRouter - Alibaba's flagship model",
79
+ "openrouter/meta-llama/llama-3.1-8b-instruct:free": "🌐 Llama 3.1 8B via OpenRouter - Free open source model",
80
+ "openrouter/microsoft/phi-3-mini-128k-instruct:free": "🌐 Phi-3 Mini via OpenRouter - Free Microsoft model"
81
+ }
82
+
83
+ self.api_keys: Dict[str, str] = {}
84
+ self.selected_provider: Optional[str] = None
85
+ self.selected_model: Optional[str] = None
86
+
87
+ def get_providers(self) -> List[str]:
88
+ """Get available providers."""
89
+ return list(self.providers_config.keys())
90
+
91
+ def get_models(self, provider: str) -> List[str]:
92
+ """Get available models for a provider."""
93
+ if provider not in self.providers_config:
94
+ logger.warning(f"Provider '{provider}' not found")
95
+ return []
96
+
97
+ return self.providers_config[provider].get('models', [])
98
+
99
+ def get_model_description(self, model: str) -> str:
100
+ """Get description for a model."""
101
+ return self.model_descriptions.get(model, "No description available")
102
+
103
+ def get_provider_description(self, provider: str) -> str:
104
+ """Get description for a provider."""
105
+ if provider not in self.providers_config:
106
+ return "Unknown provider"
107
+
108
+ return self.providers_config[provider].get('description', 'No description available')
109
+
110
+ def get_provider_display_name(self, provider: str) -> str:
111
+ """Get display name for a provider."""
112
+ if provider not in self.providers_config:
113
+ return provider
114
+
115
+ return self.providers_config[provider].get('display_name', provider)
116
+
117
+ def set_api_key(self, provider: str, api_key: str) -> None:
118
+ """Set API key for a provider."""
119
+ if provider not in self.providers_config:
120
+ logger.error(f"Cannot set API key for unknown provider: {provider}")
121
+ return
122
+
123
+ env_var = self.providers_config[provider]['api_key_env']
124
+ os.environ[env_var] = api_key
125
+ self.api_keys[provider] = api_key
126
+
127
+ logger.info(f"API key set for provider: {provider}")
128
+
129
+ def get_api_key(self, provider: str) -> Optional[str]:
130
+ """Get API key for a provider."""
131
+ if provider not in self.providers_config:
132
+ return None
133
+
134
+ # First check our internal storage
135
+ if provider in self.api_keys:
136
+ return self.api_keys[provider]
137
+
138
+ # Then check environment variable
139
+ env_var = self.providers_config[provider]['api_key_env']
140
+ return os.environ.get(env_var)
141
+
142
+ def has_api_key(self, provider: str) -> bool:
143
+ """Check if provider has an API key set."""
144
+ api_key = self.get_api_key(provider)
145
+ return api_key is not None and api_key.strip() != ""
146
+
147
+ def get_default_model(self, provider: str) -> Optional[str]:
148
+ """Get the default (first) model for a provider."""
149
+ models = self.get_models(provider)
150
+ return models[0] if models else None
151
+
152
+ def is_valid_provider(self, provider: str) -> bool:
153
+ """Check if provider is valid."""
154
+ return provider in self.providers_config
155
+
156
+ def is_valid_model(self, provider: str, model: str) -> bool:
157
+ """Check if model is valid for the given provider."""
158
+ return model in self.get_models(provider)
159
+
160
+ def get_models_with_descriptions(self, provider: str) -> Dict[str, str]:
161
+ """Get models with their descriptions for a provider."""
162
+ models = self.get_models(provider)
163
+ return {model: self.get_model_description(model) for model in models}
164
+
165
+ def clear_api_key(self, provider: str) -> None:
166
+ """Clear API key for a provider."""
167
+ if provider in self.api_keys:
168
+ del self.api_keys[provider]
169
+
170
+ if provider in self.providers_config:
171
+ env_var = self.providers_config[provider]['api_key_env']
172
+ if env_var in os.environ:
173
+ del os.environ[env_var]
174
+
175
+ logger.info(f"API key cleared for provider: {provider}")
176
+
177
+ def get_configuration_summary(self) -> Dict:
178
+ """Get a summary of current configuration."""
179
+ return {
180
+ 'total_providers': len(self.providers_config),
181
+ 'providers_with_keys': sum(1 for p in self.providers_config if self.has_api_key(p)),
182
+ 'selected_provider': self.selected_provider,
183
+ 'selected_model': self.selected_model,
184
+ 'available_models': sum(len(config['models']) for config in self.providers_config.values())
185
+ }
llm_config/ui_manager.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_config/ui_manager.py
2
+ """
3
+ UI state management with clean separation of concerns.
4
+ """
5
+
6
+ import gradio as gr
7
+ import logging
8
+ from typing import Dict, List, Optional, Tuple, Any
9
+ from .interfaces import IUIStateManager, APIKeyValidationResult, IProviderManager, INotificationService
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class UIStateManager(IUIStateManager):
15
+ """Manages UI state changes and updates for LLM configuration."""
16
+
17
+ def __init__(self, provider_manager: IProviderManager, notification_service: INotificationService):
18
+ self.provider_manager = provider_manager
19
+ self.notification_service = notification_service
20
+ self.current_provider: Optional[str] = None
21
+ self.current_model: Optional[str] = None
22
+
23
+ def update_provider_selection(self, provider: str) -> Dict:
24
+ """Update UI when provider is selected."""
25
+ try:
26
+ if not self.provider_manager.is_valid_provider(provider):
27
+ logger.warning(f"Invalid provider selected: {provider}")
28
+ return self._create_error_update("Invalid provider selected")
29
+
30
+ self.current_provider = provider
31
+ models = self.provider_manager.get_models(provider)
32
+ default_model = models[0] if models else None
33
+ self.current_model = default_model
34
+
35
+ # Check if API key is already set
36
+ has_key = self.provider_manager.has_api_key(provider)
37
+ api_key_value = self.provider_manager.get_api_key(provider) if has_key else ""
38
+
39
+ # Get provider description
40
+ provider_desc = self.provider_manager.get_provider_description(provider)
41
+
42
+ return {
43
+ 'model_dropdown': gr.update(
44
+ choices=models,
45
+ value=default_model,
46
+ visible=len(models) > 0
47
+ ),
48
+ 'api_key_input': gr.update(
49
+ value=api_key_value,
50
+ placeholder=f"Enter your {provider} API key"
51
+ ),
52
+ 'provider_info': gr.update(
53
+ value=f"**{provider}**: {provider_desc}",
54
+ visible=True
55
+ ),
56
+ 'model_description': gr.update(
57
+ value=self.provider_manager.get_model_description(default_model) if default_model else "",
58
+ visible=default_model is not None
59
+ ),
60
+ 'validation_status': gr.update(
61
+ value="✅ API key found" if has_key else "⚠️ API key required",
62
+ visible=True
63
+ )
64
+ }
65
+
66
+ except Exception as e:
67
+ logger.error(f"Error updating provider selection: {str(e)}")
68
+ return self._create_error_update(f"Error updating provider: {str(e)}")
69
+
70
+ def update_model_selection(self, model: str) -> Dict:
71
+ """Update UI when model is selected."""
72
+ try:
73
+ if not model:
74
+ return {}
75
+
76
+ self.current_model = model
77
+ model_description = self.provider_manager.get_model_description(model)
78
+
79
+ return {
80
+ 'model_description': gr.update(
81
+ value=model_description,
82
+ visible=True
83
+ )
84
+ }
85
+
86
+ except Exception as e:
87
+ logger.error(f"Error updating model selection: {str(e)}")
88
+ return self._create_error_update(f"Error updating model: {str(e)}")
89
+
90
+ def show_validation_feedback(self, result: APIKeyValidationResult) -> Dict:
91
+ """Show validation feedback to user."""
92
+ try:
93
+ if result.is_valid:
94
+ status_text = "✅ API key is valid"
95
+ status_color = "green"
96
+ if result.error_message:
97
+ status_text += f" ({result.error_message})"
98
+ else:
99
+ status_text = f"❌ {result.error_message or 'Invalid API key'}"
100
+ status_color = "red"
101
+
102
+ return {
103
+ 'validation_status': gr.update(
104
+ value=status_text,
105
+ visible=True
106
+ ),
107
+ 'api_key_feedback': gr.update(
108
+ value=status_text,
109
+ visible=True
110
+ )
111
+ }
112
+
113
+ except Exception as e:
114
+ logger.error(f"Error showing validation feedback: {str(e)}")
115
+ return self._create_error_update(f"Error showing feedback: {str(e)}")
116
+
117
+ def reset_form(self) -> Dict:
118
+ """Reset form to default state."""
119
+ try:
120
+ providers = self.provider_manager.get_providers()
121
+ default_provider = providers[0] if providers else None
122
+
123
+ if default_provider:
124
+ models = self.provider_manager.get_models(default_provider)
125
+ default_model = models[0] if models else None
126
+ else:
127
+ models = []
128
+ default_model = None
129
+
130
+ self.current_provider = default_provider
131
+ self.current_model = default_model
132
+
133
+ return {
134
+ 'provider_dropdown': gr.update(
135
+ value=default_provider,
136
+ choices=providers
137
+ ),
138
+ 'model_dropdown': gr.update(
139
+ value=default_model,
140
+ choices=models
141
+ ),
142
+ 'api_key_input': gr.update(
143
+ value="",
144
+ placeholder="Enter your API key"
145
+ ),
146
+ 'temperature_slider': gr.update(value=0.7),
147
+ 'max_retries_slider': gr.update(value=3),
148
+ 'max_scene_concurrency_slider': gr.update(value=1),
149
+ 'use_rag_checkbox': gr.update(value=True),
150
+ 'use_visual_fix_code_checkbox': gr.update(value=False),
151
+ 'use_context_learning_checkbox': gr.update(value=True),
152
+ 'verbose_checkbox': gr.update(value=False),
153
+ 'validation_status': gr.update(
154
+ value="⚠️ Configuration reset",
155
+ visible=True
156
+ ),
157
+ 'model_description': gr.update(
158
+ value=self.provider_manager.get_model_description(default_model) if default_model else "",
159
+ visible=default_model is not None
160
+ )
161
+ }
162
+
163
+ except Exception as e:
164
+ logger.error(f"Error resetting form: {str(e)}")
165
+ return self._create_error_update(f"Error resetting form: {str(e)}")
166
+
167
+ def update_helper_model_selection(self, helper_model: str) -> Dict:
168
+ """Update UI when helper model is selected."""
169
+ try:
170
+ if not helper_model:
171
+ return {}
172
+
173
+ helper_description = self.provider_manager.get_model_description(helper_model)
174
+
175
+ return {
176
+ 'helper_model_description': gr.update(
177
+ value=f"Helper: {helper_description}",
178
+ visible=True
179
+ )
180
+ }
181
+
182
+ except Exception as e:
183
+ logger.error(f"Error updating helper model selection: {str(e)}")
184
+ return {}
185
+
186
+ def get_current_configuration(self) -> Dict[str, Any]:
187
+ """Get current UI configuration state."""
188
+ return {
189
+ 'provider': self.current_provider,
190
+ 'model': self.current_model,
191
+ 'has_api_key': self.provider_manager.has_api_key(self.current_provider) if self.current_provider else False
192
+ }
193
+
194
+ def validate_current_configuration(self) -> Tuple[bool, str]:
195
+ """Validate current configuration."""
196
+ if not self.current_provider:
197
+ return False, "No provider selected"
198
+
199
+ if not self.current_model:
200
+ return False, "No model selected"
201
+
202
+ if not self.provider_manager.has_api_key(self.current_provider):
203
+ return False, "API key not set for selected provider"
204
+
205
+ return True, "Configuration is valid"
206
+
207
+ def _create_error_update(self, error_message: str) -> Dict:
208
+ """Create an error update for UI components."""
209
+ return {
210
+ 'validation_status': gr.update(
211
+ value=f"❌ {error_message}",
212
+ visible=True
213
+ )
214
+ }
215
+
216
+ def show_configuration_summary(self) -> Dict:
217
+ """Show a summary of current configuration."""
218
+ try:
219
+ config = self.get_current_configuration()
220
+ summary_text = f"""
221
+ **Current Configuration:**
222
+ - Provider: {config['provider'] or 'None'}
223
+ - Model: {config['model'] or 'None'}
224
+ - API Key: {'✅ Set' if config['has_api_key'] else '❌ Not set'}
225
+ """
226
+
227
+ return {
228
+ 'configuration_summary': gr.update(
229
+ value=summary_text,
230
+ visible=True
231
+ )
232
+ }
233
+
234
+ except Exception as e:
235
+ logger.error(f"Error showing configuration summary: {str(e)}")
236
+ return {}
llm_config/validation.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_config/validation.py
2
+ """
3
+ Provider validation implementation with proper error handling and async support.
4
+ """
5
+
6
+ import asyncio
7
+ import aiohttp
8
+ import os
9
+ import logging
10
+ from typing import Dict, List, Optional
11
+ from .interfaces import IProviderValidator, APIKeyValidationResult
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class ProviderValidator(IProviderValidator):
17
+ """Validates API keys and provider configurations."""
18
+
19
+ def __init__(self):
20
+ self.validation_endpoints = {
21
+ 'OpenAI': {
22
+ 'url': 'https://api.openai.com/v1/models',
23
+ 'headers_fn': lambda key: {'Authorization': f'Bearer {key}'}
24
+ },
25
+ 'Google Gemini': {
26
+ 'url': 'https://generativelanguage.googleapis.com/v1beta/models',
27
+ 'headers_fn': lambda key: {},
28
+ 'params_fn': lambda key: {'key': key}
29
+ },
30
+ 'Anthropic': {
31
+ 'url': 'https://api.anthropic.com/v1/messages',
32
+ 'headers_fn': lambda key: {
33
+ 'x-api-key': key,
34
+ 'anthropic-version': '2023-06-01',
35
+ 'content-type': 'application/json'
36
+ },
37
+ 'method': 'POST',
38
+ 'data': {
39
+ 'model': 'claude-3-haiku-20240307',
40
+ 'max_tokens': 1,
41
+ 'messages': [{'role': 'user', 'content': 'test'}]
42
+ }
43
+ },
44
+ 'OpenRouter': {
45
+ 'url': 'https://openrouter.ai/api/v1/models',
46
+ 'headers_fn': lambda key: {'Authorization': f'Bearer {key}'}
47
+ }
48
+ }
49
+
50
+ # Timeout for validation requests
51
+ self.timeout = 10.0
52
+
53
+ async def validate_api_key(self, provider: str, api_key: str) -> APIKeyValidationResult:
54
+ """Validate an API key for a specific provider."""
55
+ if not api_key or not api_key.strip():
56
+ return APIKeyValidationResult(
57
+ is_valid=False,
58
+ error_message="API key cannot be empty",
59
+ provider_name=provider
60
+ )
61
+
62
+ if provider not in self.validation_endpoints:
63
+ return APIKeyValidationResult(
64
+ is_valid=False,
65
+ error_message=f"Provider '{provider}' is not supported",
66
+ provider_name=provider
67
+ )
68
+
69
+ try:
70
+ endpoint_config = self.validation_endpoints[provider]
71
+ url = endpoint_config['url']
72
+ headers = endpoint_config['headers_fn'](api_key)
73
+ method = endpoint_config.get('method', 'GET')
74
+
75
+ timeout = aiohttp.ClientTimeout(total=self.timeout)
76
+
77
+ async with aiohttp.ClientSession(timeout=timeout) as session:
78
+ kwargs = {'headers': headers}
79
+
80
+ # Add query parameters if needed
81
+ if 'params_fn' in endpoint_config:
82
+ kwargs['params'] = endpoint_config['params_fn'](api_key)
83
+
84
+ # Add data for POST requests
85
+ if method == 'POST' and 'data' in endpoint_config:
86
+ kwargs['json'] = endpoint_config['data']
87
+
88
+ async with session.request(method, url, **kwargs) as response:
89
+ # Consider 200-299 as valid, and some specific error codes as invalid API key
90
+ if response.status < 300:
91
+ logger.info(f"API key validation successful for {provider}")
92
+ return APIKeyValidationResult(
93
+ is_valid=True,
94
+ provider_name=provider
95
+ )
96
+ elif response.status in [401, 403]:
97
+ logger.warning(f"Invalid API key for {provider}: {response.status}")
98
+ return APIKeyValidationResult(
99
+ is_valid=False,
100
+ error_message="Invalid API key - please check your credentials",
101
+ provider_name=provider
102
+ )
103
+ else:
104
+ logger.warning(f"Unexpected response for {provider}: {response.status}")
105
+ # For other errors, we'll assume the key might be valid but service unavailable
106
+ return APIKeyValidationResult(
107
+ is_valid=True, # Assume valid if we can't determine
108
+ error_message=f"Could not verify API key (service returned {response.status})",
109
+ provider_name=provider
110
+ )
111
+
112
+ except asyncio.TimeoutError:
113
+ logger.warning(f"Timeout validating API key for {provider}")
114
+ return APIKeyValidationResult(
115
+ is_valid=True, # Assume valid if timeout
116
+ error_message="Validation timed out - key might be valid",
117
+ provider_name=provider
118
+ )
119
+ except Exception as e:
120
+ logger.error(f"Error validating API key for {provider}: {str(e)}")
121
+ return APIKeyValidationResult(
122
+ is_valid=False,
123
+ error_message=f"Validation error: {str(e)}",
124
+ provider_name=provider
125
+ )
126
+
127
+ def get_supported_providers(self) -> List[str]:
128
+ """Get list of supported providers."""
129
+ return list(self.validation_endpoints.keys())
130
+
131
+ def validate_api_key_sync(self, provider: str, api_key: str) -> APIKeyValidationResult:
132
+ """Synchronous wrapper for API key validation."""
133
+ try:
134
+ loop = asyncio.get_event_loop()
135
+ if loop.is_running():
136
+ # If we're already in an event loop, create a new thread
137
+ import concurrent.futures
138
+ with concurrent.futures.ThreadPoolExecutor() as executor:
139
+ future = executor.submit(
140
+ lambda: asyncio.run(self.validate_api_key(provider, api_key))
141
+ )
142
+ return future.result(timeout=self.timeout + 5)
143
+ else:
144
+ return loop.run_until_complete(self.validate_api_key(provider, api_key))
145
+ except Exception as e:
146
+ logger.error(f"Error in sync validation: {str(e)}")
147
+ return APIKeyValidationResult(
148
+ is_valid=False,
149
+ error_message=f"Validation failed: {str(e)}",
150
+ provider_name=provider
151
+ )
provider.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # provider.py
2
+ """
3
+ Provider management for Theory2Manim Gradio app.
4
+ Allows user to select provider, enter API key, and select model.
5
+ """
6
+
7
+ import os
8
+ from typing import Dict, List, Optional
9
+
10
+ class ProviderManager:
11
+ def __init__(self):
12
+ # Example provider configs; extend as needed
13
+ self.providers = {
14
+ 'OpenAI': {
15
+ 'api_key_env': 'OPENAI_API_KEY',
16
+ 'models': [
17
+ 'gpt-4.1', 'gpt-4o', 'gpt-3.5-turbo'
18
+ ]
19
+ },
20
+ 'Google Gemini': {
21
+ 'api_key_env': 'GOOGLE_API_KEY',
22
+ 'models': [
23
+ 'gemini-1.5-pro-002', 'gemini-2.5-flash-preview-04-17'
24
+ ]
25
+ },
26
+ 'Anthropic': {
27
+ 'api_key_env': 'ANTHROPIC_API_KEY',
28
+ 'models': [
29
+ 'claude-3-5-sonnet-20241022', 'claude-3-haiku'
30
+ ]
31
+ },
32
+ 'OpenRouter': {
33
+ 'api_key_env': 'OPENROUTER_API_KEY',
34
+ 'models': [
35
+ 'openai/gpt-4o', 'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet',
36
+ 'anthropic/claude-3-haiku', 'google/gemini-pro-1.5', 'deepseek/deepseek-chat',
37
+ 'qwen/qwen-2.5-72b-instruct', 'meta-llama/llama-3.1-8b-instruct:free',
38
+ 'microsoft/phi-3-mini-128k-instruct:free'
39
+ ]
40
+ }
41
+ }
42
+ self.selected_provider = None
43
+ self.api_keys = {}
44
+
45
+ def get_providers(self) -> List[str]:
46
+ return list(self.providers.keys())
47
+
48
+ def get_models(self, provider: str) -> List[str]:
49
+ return self.providers.get(provider, {}).get('models', [])
50
+
51
+ def set_api_key(self, provider: str, api_key: str):
52
+ env_var = self.providers[provider]['api_key_env']
53
+ os.environ[env_var] = api_key
54
+ self.api_keys[provider] = api_key
55
+
56
+ def get_api_key(self, provider: str) -> Optional[str]:
57
+ env_var = self.providers[provider]['api_key_env']
58
+ return os.environ.get(env_var)
59
+
60
+ def get_selected_provider(self) -> Optional[str]:
61
+ return self.selected_provider
62
+
63
+ def set_selected_provider(self, provider: str):
64
+ self.selected_provider = provider
65
+
66
+ def get_selected_model(self) -> Optional[str]:
67
+ if self.selected_provider:
68
+ return self.get_models(self.selected_provider)[0]
69
+ return None
70
+
71
+ provider_manager = ProviderManager()