import customtkinter
import tkinter as tk
from tkinter import font

class E621_window(customtkinter.CTkToplevel):
    """Toplevel window for e621-based tag generation.

    Layout: a left panel of editable hyperparameters, a right panel with a
    positive-tag input, a negative-tag input and two action buttons, and a
    read-only results textbox along the bottom. Both inputs offer tag
    autocompletion (by English tag or Korean translation) built from the CSV
    passed to the constructor.

    Generation and application are delegated to ``self.master`` through its
    optional ``process_e621_request`` / ``apply_e621_request`` methods.
    """

    def __init__(self, app, e621_csv, *args, **kwargs):
        """Build the window and its autocomplete machinery.

        Args:
            app: Accepted for the caller's convenience; not used by this class
                (callbacks go through ``self.master``).
            e621_csv: Table whose first column holds English tags and whose
                second column holds Korean translations. Presumably a pandas
                DataFrame (it is read via ``.iloc``) — verify at the call site.
            *args, **kwargs: Forwarded to ``CTkToplevel``.
        """
        super().__init__(*args, **kwargs)
        self.title("e621 Tag Generation")
        self.geometry("800x540")

        self.v_large_font = customtkinter.CTkFont(family='Pretendard', size=13)

        self.e621_csv = e621_csv
        self.tag_dict = self.process_csv_data()
        # Bigram indexes that make autocomplete lookups fast.
        self.eng_index = self.create_search_index(is_korean=False)
        self.kor_index = self.create_search_index(is_korean=True)
        # Queries shorter than this are not searched.
        self.min_search_length = 2

        # Left and right panels share the width equally.
        self.grid_columnconfigure(0, weight=1)
        self.grid_columnconfigure(1, weight=1)

        self.left_frame = customtkinter.CTkFrame(self, fg_color="#242424")
        self.left_frame.grid(row=0, column=0, padx=10, pady=10, sticky="n")

        self.right_frame = customtkinter.CTkFrame(self)
        self.right_frame.grid(row=0, column=1, padx=10, pady=10, sticky="nsew")

        # Default hyperparameters. The entry widgets are seeded from these and
        # get_current_state falls back to them when an entry is invalid.
        self.parameters = {
            "freq_penalty_factor": 0.9,
            "freq_penalty_offset": 0.1,
            "word_weight": 0.1,
            "tag_weight": 0.9,
            "negative_weight": 1.5,
            "word_importance_offset": 0.04,
            "rare_word_bonus": 2.5,
            "freq_threshold_percentile": 2,
            "n_recommendations": 20,
            "randomness": 0.3,
            "max_depth": 10
        }

        # Creation order matters for keyboard focus traversal; keep it
        # parameters -> results -> inputs -> popups.
        self._build_parameter_panel()
        self._build_result_box()
        self._build_input_panel()
        self.attributes('-topmost', True)

        # Autocomplete state: the last comma-split snapshot of each input and
        # the pending debounce timer ids.
        self.last_text_prompt = []
        self.last_negative_prompt = []
        self.last_text_timer = None
        self.last_negative_timer = None
        # English tags currently listed in each popup, index-aligned with
        # the listbox rows.
        self.text_suggestions = []
        self.negative_suggestions = []

        # One shared font for both listboxes. Keep a reference so the named Tk
        # font is owned by this window rather than a discarded temporary.
        self.listbox_font = font.Font(family='Pretendard', size=13)
        self.popup_text, self.listbox_text = self._create_suggestion_popup()
        self.popup_negative, self.listbox_negative = self._create_suggestion_popup()

        # Index of the prompt element a chosen suggestion will replace, and
        # that element's raw text.
        self.text_target_index = None
        self.negative_target_index = None
        self.text_last_element = ""
        self.negative_last_element = ""

        self.setup_bindings()

        self.protocol("WM_DELETE_WINDOW", self.on_close)
        self.after(1500, lambda: self.attributes('-topmost', False))

        # Snapshot the initial prompts after 3 s. Presumably this gives the
        # caller time to pre-fill the inputs — TODO confirm.
        self.after(3000, self.init_prompts)

    def _build_parameter_panel(self):
        """Create the heading plus one label/entry row per hyperparameter."""
        self.entries = {}  # parameter name -> CTkEntry
        label0 = customtkinter.CTkLabel(self.left_frame, font=self.v_large_font, text="               하이퍼파라미터 설정")
        label0.grid(row=0, column=0, padx=5, pady=2, sticky="n")
        for row, (param, value) in enumerate(self.parameters.items(), start=1):
            label = customtkinter.CTkLabel(self.left_frame, text=param, font=self.v_large_font)
            label.grid(row=row, column=0, padx=5, pady=2, sticky="e")

            entry = customtkinter.CTkEntry(self.left_frame, width=60, font=self.v_large_font)
            entry.insert(0, str(value))
            entry.grid(row=row, column=1, padx=5, pady=2, sticky="w")

            self.entries[param] = entry

    def _build_result_box(self):
        """Create the read-only results textbox spanning both columns."""
        self.result_text = customtkinter.CTkTextbox(self, height=120, font=self.v_large_font)
        self.result_text.grid(row=1, column=0, columnspan=2, padx=5, pady=5, sticky="nsew")
        self.result_text.insert("0.0", "즉시 생성 버튼을 누르면 결과가 여기에 나타납니다... ")
        self.result_text.configure(state="disabled")

    def _build_input_panel(self):
        """Create the tag inputs and action buttons in the right panel."""
        input_label = customtkinter.CTkLabel(self.right_frame, text="e621 기반 태그 입력 (즉시 생성에만 적용)", font=self.v_large_font)
        input_label.grid(row=0, column=0, padx=5, pady=5, sticky="w")

        self.text_input = customtkinter.CTkTextbox(self.right_frame, height=100, font=self.v_large_font)
        self.text_input.grid(row=1, column=0, padx=5, pady=5, sticky="nsew")

        negative_label = customtkinter.CTkLabel(self.right_frame, text="e621 네거티브 태그 입력", font=self.v_large_font)
        negative_label.grid(row=2, column=0, padx=5, pady=5, sticky="w")

        self.negative_input = customtkinter.CTkTextbox(self.right_frame, height=100, font=self.v_large_font)
        self.negative_input.grid(row=3, column=0, padx=5, pady=5, sticky="nsew")

        self.e621_generate = customtkinter.CTkButton(
            self.right_frame,
            text="즉시 생성",
            font=self.v_large_font,
            command=self.generate_callback
        )
        self.e621_generate.grid(row=4, column=0, padx=5, pady=10, sticky="ew")

        self.apply = customtkinter.CTkButton(
            self.right_frame,
            fg_color="grey10",
            hover_color="grey",
            text="적용 후 닫기 (네거티브 태그만 적용)",
            font=self.v_large_font,
            command=self.apply_callback
        )
        self.apply.grid(row=5, column=0, padx=5, pady=10, sticky="ew")

        # Only the two text inputs (rows 1 and 3) grow vertically.
        self.right_frame.grid_columnconfigure(0, weight=1)
        for row in range(6):
            self.right_frame.grid_rowconfigure(row, weight=1 if row in (1, 3) else 0)

    def _create_suggestion_popup(self):
        """Create a hidden, borderless toplevel holding a suggestion listbox.

        Returns:
            ``(popup, listbox)``.
        """
        popup = customtkinter.CTkToplevel(self)
        popup.withdraw()
        popup.overrideredirect(True)
        listbox = tk.Listbox(popup, font=self.listbox_font, bg='#2B2B2B', fg='#F8F8F8', borderwidth=2, highlightbackground='lightgrey', width=50, height=11)
        listbox.grid(row=0, column=0, sticky="nsew")
        return popup, listbox

    def _hide_popups(self):
        """Withdraw both autocomplete popups (they are separate toplevels)."""
        self.popup_text.withdraw()
        self.popup_negative.withdraw()

    def on_close(self):
        """Hide (not destroy) the window and any open suggestion popup."""
        self._hide_popups()
        self.withdraw()

    def generate_callback(self):
        """Request recommendations from the master and display them.

        The results box shows the input tags followed by the recommended tags,
        then a blank line and the same list translated to Korean where a
        translation exists. Nothing changes if the master has no
        ``process_e621_request`` or it returns a falsy value.
        """
        current_state = self.get_current_state()
        results = None
        if hasattr(self.master, 'process_e621_request'):
            results = self.master.process_e621_request(current_state)
        if not results:
            return

        # Drop surrounding whitespace and a dangling comma so an empty or
        # comma-terminated input does not yield ", tag" / ", , tag".
        base = current_state['input_text'].strip().rstrip(',').rstrip()
        # assumes each result is a 2-item (tag, score)-like pair — only the tag is used
        all_tags = ([base] if base else []) + [tag for tag, _ in results]
        output_text = ', '.join(all_tags)

        # tag_dict keys are lower-cased, so look up case-insensitively.
        translated = [
            self.tag_dict.get(tag.lower(), tag)
            for tag in (part.strip() for part in output_text.split(','))
            if tag
        ]

        self.result_text.configure(state="normal")
        self.result_text.delete("0.0", "end")
        self.result_text.insert("end", output_text)
        self.result_text.insert("end", "\n\n" + ', '.join(translated))
        self.result_text.configure(state="disabled")

    def apply_callback(self):
        """Pass the current state and tag dictionary to the master, then hide."""
        current_state = self.get_current_state()

        if hasattr(self.master, 'apply_e621_request'):
            self.master.apply_e621_request(current_state, self.tag_dict)

        self._hide_popups()
        self.withdraw()

    def process_csv_data(self):
        """Build a ``{english_tag: korean_tag}`` dict from the CSV.

        English tags are stripped and lower-cased. Korean tags are stripped.
        Rows where either cell is missing (NaN/None/NA) or blank are skipped.
        """
        import pandas as pd  # the CSV object is already a pandas DataFrame

        def clean(value):
            # str(NaN) would otherwise turn missing cells into the tag "nan".
            return "" if pd.isna(value) else str(value).strip()

        frame = self.e621_csv
        tag_dict = {}
        for eng_value, kor_value in zip(frame.iloc[:, 0], frame.iloc[:, 1]):
            eng_tag = clean(eng_value).lower()
            kor_tag = clean(kor_value)
            if eng_tag and kor_tag:
                tag_dict[eng_tag] = kor_tag
        return tag_dict

    def create_search_index(self, is_korean=False):
        """Return a bigram index ``{two_char_substring: {english_tag, ...}}``.

        Indexes the Korean translations when ``is_korean`` is true, otherwise
        the English tags. Targets are lower-cased to match the lower-cased
        query in ``find_matching_tags``. Targets shorter than two characters
        contribute no bigrams.
        """
        index = {}
        for eng_tag, kor_tag in self.tag_dict.items():
            target = (kor_tag if is_korean else eng_tag).lower()
            for start in range(len(target) - 1):
                index.setdefault(target[start:start + 2], set()).add(eng_tag)
        return index

    def find_matching_tags(self, query: str, max_results: int = 50) -> list:
        """Return up to ``max_results`` ``(english_tag, korean_tag)`` matches.

        A tag matches when ``query`` is a case-insensitive substring of its
        English name. If the query contains a Hangul syllable, the Korean
        translation is searched instead. Matches whose searched text starts
        with the query come first, then shorter English tags, then
        alphabetical order. Queries shorter than ``min_search_length`` return
        ``[]``.
        """
        query = query.strip().lower()
        if len(query) < self.min_search_length:
            return []

        is_korean = any('가' <= char <= '힣' for char in query)
        search_index = self.kor_index if is_korean else self.eng_index

        # Intersect the tag sets of every query bigram. If any bigram appears
        # in no tag, no tag can contain the whole query.
        candidate_tags = None
        for start in range(len(query) - 1):
            postings = search_index.get(query[start:start + 2])
            if not postings:
                return []
            if candidate_tags is None:
                candidate_tags = set(postings)
            else:
                candidate_tags &= postings
            if not candidate_tags:
                return []

        def searched_text(eng_tag):
            return self.tag_dict[eng_tag].lower() if is_korean else eng_tag

        # Sharing bigrams does not imply containing the query, so confirm with
        # a real substring test. Collect every match before ranking so prefix
        # matches are never cut in favour of arbitrary set-order hits.
        matches = [
            (eng_tag, self.tag_dict[eng_tag])
            for eng_tag in candidate_tags
            if query in searched_text(eng_tag)
        ]
        matches.sort(key=lambda pair: (
            0 if searched_text(pair[0]).startswith(query) else 1,
            len(pair[0]),
            pair[0],
        ))
        return matches[:max_results]

    @staticmethod
    def _split_prompt(text_widget):
        """Return the widget's text, without Tk's trailing newline, split on commas."""
        return text_widget.get("0.0", "end-1c").split(',')

    @staticmethod
    def _find_changed_element(previous, current):
        """Locate the prompt element most likely just edited.

        Args:
            previous: Comma-split snapshot taken before the edit.
            current: Comma-split text after the edit (never empty).

        Returns:
            ``(index, element)`` into ``current``, or ``None`` if unchanged.
            With an unchanged element count, the last differing position
            wins. Otherwise the first position where the lists diverge wins,
            or the last element if one list is a prefix of the other.
        """
        if len(previous) == len(current):
            changed = [i for i, (old, new) in enumerate(zip(previous, current)) if old != new]
            if not changed:
                return None
            index = changed[-1]
        else:
            index = next(
                (i for i, (old, new) in enumerate(zip(previous, current)) if old != new),
                len(current) - 1,
            )
        return index, current[index]

    def init_prompts(self):
        """Snapshot both inputs as the baseline for change detection."""
        self.last_text_prompt = self._split_prompt(self.text_input)
        self.last_negative_prompt = self._split_prompt(self.negative_input)

    def setup_bindings(self):
        """Wire typing, focus loss and suggestion selection for both inputs."""
        self.text_input.bind('<KeyRelease>', lambda e: self.key_yield(e, 'text'))
        self.text_input.bind("<FocusOut>", lambda e: self.popup_focus_out(e, 'text'))
        self.listbox_text.bind('<<ListboxSelect>>', lambda e: self.hard_insertion(e, 'text'))

        self.negative_input.bind('<KeyRelease>', lambda e: self.key_yield(e, 'negative'))
        self.negative_input.bind("<FocusOut>", lambda e: self.popup_focus_out(e, 'negative'))
        self.listbox_negative.bind('<<ListboxSelect>>', lambda e: self.hard_insertion(e, 'negative'))

    def _autocomplete_widgets(self, input_type):
        """Return ``(text_widget, popup, listbox)`` for ``'text'`` or anything else (negative)."""
        if input_type == 'text':
            return self.text_input, self.popup_text, self.listbox_text
        return self.negative_input, self.popup_negative, self.listbox_negative

    def popup_focus_out(self, event, input_type):
        """Hide the popup belonging to the input that lost focus."""
        _, popup, _ = self._autocomplete_widgets(input_type)
        popup.withdraw()

    def hard_insertion(self, event, input_type):
        """Replace the element being edited with the tag picked from the popup.

        The input's elements are re-joined with ", " (each element stripped).
        Errors are printed rather than raised, because this runs as a Tk
        event callback.
        """
        try:
            text_widget, popup, listbox = self._autocomplete_widgets(input_type)
            is_text = input_type == 'text'

            selection = listbox.curselection()
            # <<ListboxSelect>> also fires when the selection is cleared.
            if not selection:
                return
            position = selection[0]
            selected_item = listbox.get(position)
            if not selected_item:
                return

            # Prefer the stored English tag. Splitting the "tag (translation)"
            # display text breaks for tags that themselves contain " (".
            suggestions = self.text_suggestions if is_text else self.negative_suggestions
            if position < len(suggestions):
                keyword = suggestions[position]
            else:
                keyword = selected_item.split(" (")[0]

            previous_prompt = self.last_text_prompt if is_text else self.last_negative_prompt
            target_index = self.text_target_index if is_text else self.negative_target_index

            popup.withdraw()

            if target_index is not None and target_index < len(previous_prompt):
                new_prompt = [key.strip() for key in previous_prompt]
                new_prompt[target_index] = keyword
                text_widget.delete("0.0", "end")
                text_widget.insert("0.0", ', '.join(new_prompt))
                # Re-snapshot the rewritten text so the normalised spacing is
                # not detected as an edit on the next key release.
                if is_text:
                    self.last_text_prompt = self._split_prompt(text_widget)
                else:
                    self.last_negative_prompt = self._split_prompt(text_widget)

            if is_text:
                self.text_target_index = None
                self.text_last_element = ""
            else:
                self.negative_target_index = None
                self.negative_last_element = ""

        except Exception as e:
            print(f"Error in hard_insertion: {e}")

    def delayed_key_yield(self, input_type):
        """Refresh the autocomplete popup for an input after typing pauses.

        Compares the input's comma-split elements with the previous snapshot
        and searches tags for the edited element. The matches are shown in a
        popup just below the input. The popup is hidden when nothing changed
        or nothing matches. Errors are printed, because this runs from a Tk
        ``after`` callback.
        """
        try:
            text_widget, popup, listbox = self._autocomplete_widgets(input_type)
            is_text = input_type == 'text'
            previous_prompt = self.last_text_prompt if is_text else self.last_negative_prompt
            current_prompt = self._split_prompt(text_widget)

            # Re-snapshot on every pass so the next comparison is against what
            # the widget actually contains.
            if is_text:
                self.last_text_prompt = current_prompt
            else:
                self.last_negative_prompt = current_prompt

            changed = self._find_changed_element(previous_prompt, current_prompt)
            if changed is None:
                popup.withdraw()
                return
            changed_index, changed_element = changed

            # Returns [] for queries shorter than min_search_length.
            matching_tags = self.find_matching_tags(changed_element.strip())
            if not matching_tags:
                popup.withdraw()
                return

            suggestions = [eng_tag for eng_tag, _ in matching_tags]
            if is_text:
                self.text_last_element = changed_element
                self.text_target_index = changed_index
                self.text_suggestions = suggestions
            else:
                self.negative_last_element = changed_element
                self.negative_target_index = changed_index
                self.negative_suggestions = suggestions

            listbox.delete(0, 'end')
            for eng_tag, kor_tag in matching_tags:
                listbox.insert('end', f"{eng_tag} ({kor_tag})")

            # Place the popup directly under the input widget.
            x = text_widget.winfo_rootx()
            y = text_widget.winfo_rooty() + text_widget.winfo_height()
            popup.geometry(f"+{x}+{y}")
            popup.deiconify()

        except Exception as e:
            print(f"Error in delayed_key_yield: {e}")

    def key_yield(self, event, input_type):
        """Debounce key releases: run delayed_key_yield 200 ms after the last one."""
        pending = self.last_text_timer if input_type == 'text' else self.last_negative_timer
        if pending:
            self.after_cancel(pending)

        timer = self.after(200, lambda: self.delayed_key_yield(input_type))

        if input_type == 'text':
            self.last_text_timer = timer
        else:
            self.last_negative_timer = timer

    @staticmethod
    def _parse_number(text):
        """Parse ``text`` as an int if possible, otherwise as a finite float.

        Raises:
            ValueError: if ``text`` is not a finite number.
        """
        import math

        text = text.strip()
        try:
            return int(text)
        except ValueError:
            value = float(text)  # raises ValueError on garbage
        if not math.isfinite(value):
            raise ValueError(f"non-finite value: {text!r}")
        return value

    def get_current_state(self):
        """Return the current hyperparameters and both input texts.

        Returns:
            dict with ``'parameters'`` (name -> int or float), ``'input_text'``
            and ``'negative_text'``. An entry that does not parse as a finite
            number is reported on stdout and replaced by its default from
            ``self.parameters``, so every parameter key is always present.
        """
        parameters = {}
        for param, entry in self.entries.items():
            raw = entry.get()
            try:
                parameters[param] = self._parse_number(raw)
            except ValueError:
                default = self.parameters[param]
                print(f"Invalid value for {param}: {raw!r}; using default {default}")
                parameters[param] = default

        return {
            'parameters': parameters,
            'input_text': self.text_input.get("0.0", "end-1c"),
            'negative_text': self.negative_input.get("0.0", "end-1c")
        }
    
    