from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import math


class ResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # Skip connection: identity by default, or a 1x1 convolution when the
        # spatial size or channel count changes so shapes match for the addition
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.shortcut(x)
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        out += residual  # Add the residual connection
        out = F.relu(out)  # Final ReLU
        return out
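

# A minimal shape-check sketch for ResidualBlock (not called at import time).
# The batch size and the 8x8 input below are illustrative assumptions, not
# values fixed by the block itself.
def _residual_block_shape_check() -> None:
    block = ResidualBlock(64, 128, stride=2)
    x = torch.randn(4, 64, 8, 8)  # (batch, channels, height, width)
    out = block(x)
    # stride=2 halves the spatial size; channels grow to out_channels
    assert out.shape == (4, 128, 4, 4)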


class ChessNet(nn.Module):
    def __init__(self, device: Optional[torch.device] = None) -> None:
        super().__init__()

        # Use the passed device, or automatically select the best one available
        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Initial convolution over the 21 input feature planes
        self.conv1 = nn.Conv2d(21, 64, kernel_size=3, padding=1)
        
        # First block group - reduced from 6 to 2 blocks
        self.res_block1_1 = ResidualBlock(64, 64)
        self.res_block1_2 = ResidualBlock(64, 64)
        # These blocks are commented out to reduce model size
        # self.res_block1_3 = ResidualBlock(64, 64)
        # self.res_block1_4 = ResidualBlock(64, 64)
        # self.res_block1_5 = ResidualBlock(64, 64)
        # self.res_block1_6 = ResidualBlock(64, 64)
        
        # Second block group - reduced from 6 to 2 blocks
        self.res_block2_1 = ResidualBlock(64, 128, stride=2)  # Keep this for downsampling
        self.res_block2_2 = ResidualBlock(128, 128)
        # These blocks are commented out to reduce model size
        # self.res_block2_3 = ResidualBlock(128, 128)
        # self.res_block2_4 = ResidualBlock(128, 128)
        # self.res_block2_5 = ResidualBlock(128, 128)
        # self.res_block2_6 = ResidualBlock(128, 128)

        # Fully connected layer and output heads.
        # After the stride-2 block an 8x8 input is downsampled to 4x4, hence 128 * 4 * 4 features.
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.policy_head = nn.Linear(256, 8 * 8 * 73)  # 8x8 squares x 73 move planes
        self.value_head = nn.Linear(256, 1)  # scalar evaluation, squashed with tanh in forward()
        
        # Initialize weights with appropriate methods
        self._initialize_weights()
        
        # Move the model to the selected device
        self.to(self.device)
        print(f"ChessNet running on: {self.device}")

    def _initialize_weights(self):
        """
        Initialize network weights using appropriate methods:
        - He/Kaiming initialization for convolutional layers and ReLU activations
        - Xavier/Glorot initialization for fully connected layers
        - Custom small initialization for policy and value heads
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming/He initialization for Conv layers with ReLU
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                if m is self.policy_head or m is self.value_head:
                    # Small initialization for the output heads to avoid large
                    # initial policy logits / value outputs (the value head feeds tanh)
                    nn.init.normal_(m.weight, mean=0.0, std=0.01)
                    nn.init.constant_(m.bias, 0)
                else:
                    # Xavier/Glorot for fully connected layers
                    nn.init.xavier_uniform_(m.weight)
                    nn.init.constant_(m.bias, 0)

    def forward(
        self, x: torch.Tensor, legal_moves_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # Initial processing
        x = F.relu(self.conv1(x))
        
        # First residual block group (reduced)
        x = self.res_block1_1(x)
        x = self.res_block1_2(x)
        # Commented out blocks
        # x = self.res_block1_3(x)
        # x = self.res_block1_4(x)
        # x = self.res_block1_5(x)
        # x = self.res_block1_6(x)
        
        # Second residual block group (reduced)
        x = self.res_block2_1(x)
        x = self.res_block2_2(x)
        # Commented out blocks
        # x = self.res_block2_3(x)
        # x = self.res_block2_4(x)
        # x = self.res_block2_5(x)
        # x = self.res_block2_6(x)

        # Output heads
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))

        policy_logits = self.policy_head(x)
        
        if legal_moves_mask is not None:
            # Masked softmax over legal moves only. Subtracting the per-row max
            # logit keeps the exponentials numerically stable; the shift cancels
            # in the normalisation, so the result equals a softmax restricted to
            # the legal moves.
            max_logits = torch.max(policy_logits, dim=1, keepdim=True)[0]
            exp_logits = torch.exp(policy_logits - max_logits)

            masked_exp_logits = exp_logits * legal_moves_mask
            # Clamp the normaliser to avoid division by zero if no move is legal
            sum_exp = torch.sum(masked_exp_logits, dim=1, keepdim=True).clamp(min=1e-10)
            policy_probs = masked_exp_logits / sum_exp
            return policy_probs, torch.tanh(self.value_head(x))

        # No mask supplied: return raw policy logits alongside the tanh-squashed value
        value = torch.tanh(self.value_head(x))
        return policy_logits, value
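

# A minimal usage sketch, assuming board encodings of shape (batch, 21, 8, 8)
# and a flat legal-move mask of shape (batch, 8 * 8 * 73); the dummy tensors and
# the all-ones mask below are illustrative assumptions, not part of the model.
if __name__ == "__main__":
    net = ChessNet(device=torch.device("cpu"))
    boards = torch.randn(2, 21, 8, 8, device=net.device)
    mask = torch.ones(2, 8 * 8 * 73, device=net.device)  # pretend every move is legal

    # With a mask: probabilities over legal moves plus a value in [-1, 1]
    probs, value = net(boards, legal_moves_mask=mask)
    print(probs.shape, value.shape)  # torch.Size([2, 4672]) torch.Size([2, 1])
    print(probs.sum(dim=1))          # each row sums to ~1.0

    # Without a mask: raw policy logits and the same tanh-squashed value
    logits, value = net(boards)
    print(logits.shape, value.shape)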