import hashlib
import os
import pickle
import random
import re
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd

class WordLevelTagRecommender:
    """Recommend tags for a prompt from word/tag co-occurrence statistics.

    ``df`` is read positionally: column 0 is the prompt text, column 2 a
    ``', '``-separated list of tags and column 3 a ``', '``-separated list of
    integer counts aligned with those tags (other columns are not read here).
    Rows that share the same prompt text write into the same statistics entry.
    The derived statistics are pickled to ``cache_file`` and reused while both
    the hyperparameters and the DataFrame contents are unchanged.
    """

    # Anything that is neither a word character nor whitespace becomes a space.
    _PUNCTUATION_RE = re.compile(r'[^\w\s]')

    # Keys a cache file must contain to be usable.
    _CACHE_KEYS = ('word_to_prompts', 'tag_frequency', 'tag_weights',
                   'word_importance', 'word_frequency', 'hyperparams')

    def __init__(self,
                 df,
                 cache_file: str = "tag_recommender_cache.pkl",
                 freq_penalty_factor: float = 0.67,
                 freq_penalty_offset: float = 0.15,
                 word_weight: float = 0.2,
                 tag_weight: float = 0.8,
                 negative_weight: float = 1.5,
                 word_importance_offset: float = 0.1,
                 rare_word_bonus: float = 1.0,
                 freq_threshold_percentile: int = 1):
        """Load the statistics for ``df`` from the cache, or build and cache them.

        Args:
            df: Training data (column layout described on the class).
            cache_file: Pickle file used to persist the derived statistics.
            freq_penalty_factor: Scale of the log penalty applied to the
                importance of words more frequent than the threshold.
            freq_penalty_offset: Added to ``threshold / frequency`` inside
                that log penalty.
            word_weight: Multiplier for scores reached through the individual
                words of the input.
            tag_weight: Multiplier for scores reached through the input's
                comma-separated entries.
            negative_weight: Multiplier for the influence of negative tags.
            word_importance_offset: Added to a word's importance before it is
                inverted into a matching weight.
            rare_word_bonus: Importance multiplier for words whose frequency
                is at or below the threshold.
            freq_threshold_percentile: Position (percent of the vocabulary,
                ranked by frequency, minimum index 1) whose count is used as
                the frequency threshold.
        """
        self.df = df
        self.cache_file = cache_file

        self.hyperparams = {
            'freq_penalty_factor': freq_penalty_factor,
            'freq_penalty_offset': freq_penalty_offset,
            'word_weight': word_weight,
            'tag_weight': tag_weight,
            'negative_weight': negative_weight,
            'word_importance_offset': word_importance_offset,
            'rare_word_bonus': rare_word_bonus,
            'freq_threshold_percentile': freq_threshold_percentile
        }

        self._reset_statistics()

        self.stopwords = {'a', 'an', 'and', 'are', 'as', 'at', 'be', 'by', 'for', 'from',
                          'has', 'he', 'in', 'is', 'it', 'its', 'of', 'on', 'that', 'the',
                          'to', 'was', 'were', 'will', 'with', 'this', 'but', 'they'}

        # Tags that are never scored, neither as candidates nor as negative evidence.
        self.ignored_tags = {'female', 'mammal', 'genitals', 'breasts', 'male',
                             'hair', 'anthro', 'humanoid', 'fur'}

        # Read the cache once and reuse it for both the staleness check and loading.
        cache_data = self._read_cache_file()
        if self._cache_is_current(cache_data):
            self._apply_cache(cache_data)
        else:
            self.build_frequency_matrices()
            self.save_cache()

    # ------------------------------------------------------------------ cache

    def _reset_statistics(self):
        """Replace every derived statistic with an empty container."""
        # word -> set of prompt texts containing it (stopwords excluded)
        self.word_to_prompts = defaultdict(set)
        # prompt text -> tag -> that tag's share of the row's total count
        self.tag_frequency = defaultdict(lambda: defaultdict(float))
        # tag -> share of all counts in the data
        self.tag_weights = defaultdict(float)
        # word -> idf * frequency penalty
        self.word_importance = defaultdict(float)
        # word -> number of occurrences across all prompts (stopwords excluded)
        self.word_frequency = defaultdict(int)

    def _data_fingerprint(self) -> Optional[str]:
        """Return a digest of ``self.df``'s contents, or None if it cannot be hashed."""
        try:
            row_hashes = pd.util.hash_pandas_object(self.df, index=False).to_numpy()
        except (TypeError, ValueError):
            return None
        digest = hashlib.sha256(repr(self.df.shape).encode('utf-8'))
        digest.update(row_hashes.tobytes())
        return digest.hexdigest()

    def _read_cache_file(self) -> Optional[dict]:
        """Return the unpickled cache, or None if it is missing, unreadable or incomplete."""
        if not os.path.exists(self.cache_file):
            return None
        try:
            # SECURITY: unpickling can execute arbitrary code; only use trusted cache files.
            with open(self.cache_file, 'rb') as f:
                cache_data = pickle.load(f)
        except Exception:  # a corrupt pickle can raise almost any exception type
            return None
        if not isinstance(cache_data, dict):
            return None
        if any(key not in cache_data for key in self._CACHE_KEYS):
            return None
        return cache_data

    def _cache_is_current(self, cache_data: Optional[dict]) -> bool:
        """True when ``cache_data`` was built from the current hyperparameters and data."""
        if cache_data is None:
            return False
        cached_params = cache_data.get('hyperparams', {})
        try:
            for key, value in self.hyperparams.items():
                if key not in cached_params or abs(cached_params[key] - value) > 1e-6:
                    return False
        except TypeError:  # non-numeric or malformed cached hyperparameters
            return False
        # Caches written without a fingerprint are treated as stale unless the
        # current data cannot be fingerprinted either.
        return cache_data.get('data_fingerprint') == self._data_fingerprint()

    def _apply_cache(self, cache_data: dict):
        """Install the statistics stored in ``cache_data`` on this instance.

        Raises:
            KeyError: if one of the statistics is missing from ``cache_data``.
        """
        self.word_to_prompts = defaultdict(set, cache_data['word_to_prompts'])
        self.tag_frequency = defaultdict(lambda: defaultdict(float), cache_data['tag_frequency'])
        self.tag_weights = defaultdict(float, cache_data['tag_weights'])
        self.word_importance = defaultdict(float, cache_data['word_importance'])
        self.word_frequency = defaultdict(int, cache_data['word_frequency'])

    def should_rebuild_cache(self) -> bool:
        """Return True when the cache file is missing, unreadable or stale.

        The cache is stale when any hyperparameter differs by more than 1e-6
        or when it was built from different DataFrame contents.
        """
        return not self._cache_is_current(self._read_cache_file())

    def save_cache(self):
        """Pickle the current statistics, hyperparameters and data fingerprint.

        The file is written to a temporary sibling and then moved into place,
        so an interrupted write cannot leave a truncated cache behind.
        """
        cache_data = {
            'word_to_prompts': dict(self.word_to_prompts),
            # The outer defaultdict's lambda factory is not picklable.
            'tag_frequency': dict(self.tag_frequency),
            'tag_weights': dict(self.tag_weights),
            'word_importance': dict(self.word_importance),
            'word_frequency': dict(self.word_frequency),
            'hyperparams': dict(self.hyperparams),
            'data_fingerprint': self._data_fingerprint(),
        }
        tmp_path = f'{self.cache_file}.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(cache_data, f)
        os.replace(tmp_path, self.cache_file)

    def load_cache(self):
        """Load the statistics from ``cache_file`` without any staleness check.

        Raises:
            OSError: if the file cannot be opened.
            KeyError: if a statistic is missing from the pickled data.
        """
        # SECURITY: unpickling can execute arbitrary code; only use trusted cache files.
        with open(self.cache_file, 'rb') as f:
            cache_data = pickle.load(f)
        self._apply_cache(cache_data)

    # -------------------------------------------------------------- building

    def tokenize(self, text: str) -> List[str]:
        """Lower-case ``text``, turn punctuation into spaces and split on whitespace."""
        return self._PUNCTUATION_RE.sub(' ', text.lower()).split()

    @staticmethod
    def _split_tags(text: str) -> List[str]:
        """Split ``text`` on commas and strip each entry; [] for empty text."""
        return [tag.strip() for tag in text.split(',')] if text else []

    @classmethod
    def _unique_tags(cls, text: str) -> List[str]:
        """Return the non-empty comma-separated entries of ``text``, de-duplicated in order."""
        return list(dict.fromkeys(tag for tag in cls._split_tags(text) if tag))

    def _frequency_threshold(self) -> int:
        """Return the word count at the configured percentile rank (0 for an empty vocabulary)."""
        counts = sorted(self.word_frequency.values(), reverse=True)
        if not counts:
            return 0
        index = max(1, len(counts) * self.hyperparams['freq_threshold_percentile'] // 100)
        # Clamp so tiny vocabularies or a 100th percentile do not index past the end.
        return counts[min(index, len(counts) - 1)]

    def build_frequency_matrices(self):
        """(Re)compute every statistic from ``self.df``.

        Existing statistics are cleared first, so calling this again (for
        example from ``update_parameters``) does not double-count.

        Raises:
            ValueError: if an entry of the counts column is not an integer.
        """
        self._reset_statistics()
        total_prompts = len(self.df)

        for _, row in self.df.iterrows():
            prompt = str(row.iloc[0])
            words = [word for word in self.tokenize(prompt) if word not in self.stopwords]
            for word in words:
                self.word_frequency[word] += 1
            for word in set(words):
                self.word_to_prompts[word].add(prompt)

            tags = str(row.iloc[2]).split(', ')
            counts = [int(count) for count in str(row.iloc[3]).split(', ')]
            total_count = sum(counts)
            if total_count > 0:
                for tag, count in zip(tags, counts):
                    self.tag_frequency[prompt][tag] = count / total_count
                    self.tag_weights[tag] += count

        freq_threshold = self._frequency_threshold()
        rare_word_bonus = self.hyperparams['rare_word_bonus']
        penalty_factor = self.hyperparams['freq_penalty_factor']
        penalty_offset = self.hyperparams['freq_penalty_offset']
        for word, prompts in self.word_to_prompts.items():
            idf = np.log(total_prompts / (len(prompts) + 1))
            frequency = self.word_frequency[word]
            if frequency > freq_threshold:
                # Words above the threshold get a log penalty instead of the bonus.
                frequency_penalty = (penalty_factor *
                                     np.log(freq_threshold / frequency + penalty_offset))
            else:
                frequency_penalty = rare_word_bonus
            self.word_importance[word] = idf * frequency_penalty

        total_weight = sum(self.tag_weights.values())
        if total_weight > 0:
            for tag in self.tag_weights:
                self.tag_weights[tag] /= total_weight

    # --------------------------------------------------------------- scoring

    def find_related_prompts(self, words: List[str]) -> Dict[str, float]:
        """Score the training prompts that contain any of ``words``.

        Each known word adds ``1 / (word_importance + word_importance_offset)``
        to every prompt it occurs in (a repeated word counts again). Scores are
        then divided by the maximum, so the highest-scoring prompt gets 1.0.

        NOTE(review): importance grows with IDF, so this inverse weighting
        favours common words over rare ones. Importance can also be zero or
        negative (e.g. the IDF term is <= 0 for a word found in all or all but
        one prompt), and dividing by a negative maximum flips the ranking.
        Confirm this is intended for the tuned defaults.
        """
        offset = self.hyperparams['word_importance_offset']
        prompt_scores = defaultdict(float)
        for word in words:
            if word not in self.word_to_prompts:
                continue
            word_weight = 1.0 / (self.word_importance[word] + offset)
            for prompt in self.word_to_prompts[word]:
                prompt_scores[prompt] += word_weight

        if prompt_scores:
            max_score = max(prompt_scores.values())
            if max_score != 0:  # avoid dividing by zero
                for prompt in prompt_scores:
                    prompt_scores[prompt] /= max_score
        return prompt_scores

    def calculate_tag_scores(self, input_text: str, existing_tags: str = "",
                             negative_tags: str = "", randomness: float = 0.2) -> Dict[str, float]:
        """Score every candidate tag for ``input_text``.

        Prompts are matched through the words of ``input_text`` (scaled by the
        ``word_weight`` hyperparameter) and through its comma-separated entries
        (scaled by ``tag_weight``). For each tag of a matched prompt this adds
        ``freq * self.tag_weights[tag] * prompt_score`` times that multiplier,
        jittered by a uniform factor in ``[1 - randomness, 1 + randomness]``.
        Prompts matched by ``negative_tags`` accumulate a negative influence
        that scales a tag's score by ``1 / (1 + influence)``. Tags listed in
        ``negative_tags`` themselves get a score of -1.

        Tags in ``existing_tags`` or ``self.ignored_tags`` are never scored.
        """
        word_prompt_scores = self.find_related_prompts(self.tokenize(input_text))
        tag_prompt_scores = self.find_related_prompts(self._split_tags(input_text))
        negative_tag_list = self._split_tags(negative_tags)
        negative_prompt_scores = self.find_related_prompts(negative_tag_list)

        existing_tag_set = set(self._split_tags(existing_tags))
        negative_tag_set = set(negative_tag_list)

        def is_candidate(tag: str) -> bool:
            return tag not in existing_tag_set and tag not in self.ignored_tags

        negative_weight = self.hyperparams['negative_weight']
        negative_influence = defaultdict(float)
        for prompt, prompt_score in negative_prompt_scores.items():
            for tag, freq in self.tag_frequency.get(prompt, {}).items():
                if is_candidate(tag):
                    negative_influence[tag] += (freq * self.tag_weights[tag] * prompt_score *
                                               negative_weight)

        tag_scores = defaultdict(float)
        # Word matches are processed before entry matches, fixing the order of random draws.
        for prompt_scores, weight_key in ((word_prompt_scores, 'word_weight'),
                                          (tag_prompt_scores, 'tag_weight')):
            multiplier = self.hyperparams[weight_key]
            for prompt, prompt_score in prompt_scores.items():
                for tag, freq in self.tag_frequency.get(prompt, {}).items():
                    if not is_candidate(tag):
                        continue
                    base_score = freq * self.tag_weights[tag] * prompt_score
                    random_factor = 1.0 + random.uniform(-randomness, randomness)
                    tag_scores[tag] += multiplier * base_score * random_factor

        for tag in tag_scores:
            if tag in negative_tag_set:
                tag_scores[tag] = -1
                continue
            influence = negative_influence.get(tag, 0.0)
            if influence > 0:
                tag_scores[tag] *= 1.0 / (1.0 + influence)

        return tag_scores

    @staticmethod
    def _rank_non_negative(tag_scores: Dict[str, float]) -> List[Tuple[str, float]]:
        """Drop negative scores and sort the remaining ``(tag, score)`` pairs best first."""
        return sorted(((tag, score) for tag, score in tag_scores.items() if score >= 0),
                      key=lambda item: item[1], reverse=True)

    # -------------------------------------------------------- recommendation

    def _lowest_scoring_added_tag(self, input_text: str, protected_tags: Set[str],
                                  current_tags: List[str], negative_tags: str,
                                  randomness: float) -> Optional[str]:
        """Return the recursion-added tag with the lowest non-negative score, or None.

        Added tags are the entries of ``current_tags`` that are not in
        ``protected_tags``. They are scored with only the protected tags passed
        as existing tags, because ``calculate_tag_scores`` never scores an
        existing tag. Scoring against the full working set would therefore
        leave nothing to compare.
        """
        added_tags = {tag for tag in current_tags if tag not in protected_tags}
        if not added_tags:
            return None
        scores = self.calculate_tag_scores(input_text, ', '.join(sorted(protected_tags)),
                                           negative_tags, randomness)
        added_scores = {tag: score for tag, score in scores.items()
                        if tag in added_tags and score >= 0}
        if not added_scores:
            return None
        return min(added_scores, key=added_scores.get)

    def recommend_tags_recursive(self,
                                 input_text: str,
                                 existing_tags: str = "",
                                 negative_tags: str = "",
                                 n_recommendations: int = 10,
                                 randomness: float = 0.2,
                                 depth: int = 0,
                                 max_depth: int = 10,
                                 initial_input: str = None) -> List[Tuple[str, float]]:
        """Grow the working tag set step by step and return recommendations.

        Each step scores the candidate tags. It then randomly picks one tag
        from the upper half and one from the lower half of the ranking, appends
        them to both ``input_text`` and ``existing_tags``, and recurses with
        ``depth + 1``.

        Entries of ``existing_tags`` that are not in ``initial_input`` count as
        added by the recursion:
        * When ``depth % 3 == 2``, the lowest-scoring added tag is removed.
        * At ``max_depth``, the lowest-scoring added tag is swapped for the best
          remaining candidate, and the candidates are scored once more.

        Args:
            input_text: Text used for matching (words and comma-separated entries).
            existing_tags: Comma-separated tags of the working set; never recommended.
            negative_tags: Comma-separated tags to steer away from.
            n_recommendations: Maximum number of results.
            randomness: Half-width of the uniform score jitter.
            depth: Current recursion depth.
            max_depth: Depth at which the search stops.
            initial_input: Comma-separated tags protected from pruning
                (defaults to ``input_text`` on the first call).

        Returns:
            Up to ``n_recommendations`` ``(tag, score)`` pairs with non-negative
            scores, best first. The ranking from the last step is returned when
            no new tag can be added.
        """
        if initial_input is None:
            initial_input = input_text

        tag_scores = self.calculate_tag_scores(input_text, existing_tags, negative_tags, randomness)
        sorted_tags = self._rank_non_negative(tag_scores)

        protected_tags = set(self._unique_tags(initial_input))
        current_tags = self._unique_tags(existing_tags)

        if depth >= max_depth:
            lowest_tag = self._lowest_scoring_added_tag(input_text, protected_tags, current_tags,
                                                        negative_tags, randomness)
            if lowest_tag is None:
                return sorted_tags[:n_recommendations]

            negative_tag_set = set(self._split_tags(negative_tags))
            kept_tags = [tag for tag in current_tags if tag != lowest_tag]
            input_parts = [part for part in self._split_tags(input_text) if part != lowest_tag]
            kept_set = set(kept_tags)
            replacement = next((tag for tag, _ in sorted_tags
                                if tag not in kept_set and tag not in negative_tag_set), None)
            if replacement is not None:
                kept_tags.append(replacement)
                input_parts.append(replacement)

            final_scores = self.calculate_tag_scores(', '.join(input_parts), ', '.join(kept_tags),
                                                     negative_tags, randomness)
            return self._rank_non_negative(final_scores)[:n_recommendations]

        if depth % 3 == 2:
            lowest_tag = self._lowest_scoring_added_tag(input_text, protected_tags, current_tags,
                                                        negative_tags, randomness)
            if lowest_tag is not None:
                current_tags = [tag for tag in current_tags if tag != lowest_tag]
                input_text = ', '.join(part for part in self._split_tags(input_text)
                                       if part != lowest_tag)

        # Pick one tag at random from each half of the ranking.
        current_set = set(current_tags)
        mid_point = max(1, len(sorted_tags) // 2)
        selected_tags = []
        for half in (sorted_tags[:mid_point], sorted_tags[mid_point:]):
            if half:
                tag = random.choice(half)[0]
                # Whole-tag comparison (a substring test would reject "cat" next to "catgirl").
                if tag not in current_set:
                    selected_tags.append(tag)

        if not selected_tags:
            return sorted_tags[:n_recommendations]

        return self.recommend_tags_recursive(
            input_text=input_text + ', ' + ', '.join(selected_tags),
            existing_tags=', '.join(current_tags + selected_tags),
            negative_tags=negative_tags,
            n_recommendations=n_recommendations,
            randomness=randomness,
            depth=depth + 1,
            max_depth=max_depth,
            initial_input=initial_input
        )

    def update_parameters(self, parameters):
        """Apply new settings and rebuild the statistics from ``self.df``.

        Keys naming a hyperparameter update ``self.hyperparams``, which is what
        the scoring code reads. Every key is also set as an instance attribute,
        so e.g. ``{'df': new_df}`` swaps the training data. The cache file is
        not rewritten.
        """
        for param, value in parameters.items():
            if param in self.hyperparams:
                self.hyperparams[param] = value
            setattr(self, param, value)

        self.build_frequency_matrices()

    def recommend_tags(self,
                       input_text: str,
                       negative_tags: str = "",
                       existing_tags: str = "",
                       n_recommendations: int = 10,
                       randomness: float = 0.2) -> List[Tuple[str, float]]:
        """Recommend tags for ``input_text``.

        The comma-separated entries of ``input_text`` and ``existing_tags`` are
        treated as already present: they are never recommended and never
        pruned by the recursive search. Note that ``negative_tags`` comes
        before ``existing_tags`` positionally.
        """
        seed_tags = ', '.join(dict.fromkeys(self._unique_tags(input_text) +
                                            self._unique_tags(existing_tags)))
        return self.recommend_tags_recursive(
            input_text=input_text,
            existing_tags=seed_tags,
            negative_tags=negative_tags,
            n_recommendations=n_recommendations,
            randomness=randomness,
            initial_input=seed_tags
        )