// Utility functions shared by auth server and client integrations
// Typically these functions should be used inside AuthenticationInitiator and AuthenticationResolver implementations
import type {
  AuthStorage,
  Endpoints,
  OIDCTokenResponseBody,
  ParsedTokens,
} from "@/types.js";
import {
  AUTH_SERVER_LEGACY_SESSION,
  AUTH_SERVER_SESSION,
  OAuthTokenTypes,
} from "./types.js";
import { OAuth2Client } from "oslo/oauth2";
import { getIssuerVariations, getOauthEndpoints } from "@/lib/oauth.js";
import * as jose from "jose";
import { withoutUndefined } from "@/utils.js";
import type { PKCEConsumer, PKCEProducer } from "@/services/types.js";
import { GenericUserSession } from "@/shared/lib/UserSession.js";
import { decodeJwt, type JWTPayload } from "jose";
import type { CookieStorage } from "./storage.js";
import {
  AUTOREFRESH_TIMEOUT_NAME,
  LOGOUT_STATE,
  REFRESH_IN_PROGRESS,
} from "@/constants.js";

/**
 * Turns a PKCE code verifier into the matching code challenge.
 *
 * With the default "S256" method the verifier is hashed with SHA-256 and the
 * digest is returned base64url-encoded without trailing padding. With "Plain"
 * the verifier itself is returned and a warning is logged.
 *
 * @param codeVerifier - the PKCE code verifier
 * @param method - challenge method; "S256" unless stated otherwise
 * @returns the code challenge string
 */
export async function deriveCodeChallenge(
  codeVerifier: string,
  method: "Plain" | "S256" = "S256",
): Promise<string> {
  if (method !== "Plain") {
    const verifierBytes = new TextEncoder().encode(codeVerifier);
    const hashBytes = new Uint8Array(
      await crypto.subtle.digest("SHA-256", verifierBytes),
    );
    // btoa works on a binary string, so map each digest byte to one character
    const binary = Array.from(hashBytes, (byte) =>
      String.fromCharCode(byte),
    ).join("");
    // base64 -> base64url: drop the padding and swap the two URL-unsafe characters
    return btoa(binary)
      .replace(/=+$/, "")
      .replace(/[+/]/g, (ch) => (ch === "+" ? "-" : "_"));
  }

  console.warn("Using insecure plain code challenge method");
  return codeVerifier;
}

/**
 * Loads the endpoints for `oauthServer` via getOauthEndpoints and applies the
 * caller's overrides on top; any key present in `endpointOverrides` wins.
 *
 * @param oauthServer - server identifier handed to getOauthEndpoints
 * @param endpointOverrides - optional replacements for individual endpoints
 * @returns the merged endpoint set
 */
export async function getEndpointsWithOverrides(
  oauthServer: string,
  endpointOverrides: Partial<Endpoints> = {},
): Promise<Endpoints> {
  const baseEndpoints = await getOauthEndpoints(oauthServer);
  const merged: Endpoints = { ...baseEndpoints, ...endpointOverrides };
  return merged;
}

/**
 * Builds the authorization (login) URL for the configured OAuth server.
 *
 * The base URL comes from the OAuth2 client; the PKCE challenge, the optional
 * nonce and the consent prompt are appended as extra query parameters.
 *
 * @param config - client, redirect, state/scopes and PKCE settings for the login request
 * @returns the fully parameterised authorization URL
 */
export async function generateOauthLoginUrl(config: {
  clientId: string;
  scopes: string[];
  state: string;
  redirectUrl: string;
  oauthServer: string;
  nonce?: string;
  endpointOverrides?: Partial<Endpoints>;
  // used to get the PKCE challenge
  pkceConsumer: PKCEConsumer;
}): Promise<URL> {
  const {
    clientId,
    redirectUrl,
    oauthServer,
    endpointOverrides,
    pkceConsumer,
    state,
    scopes,
    nonce,
  } = config;

  const resolvedEndpoints = await getEndpointsWithOverrides(
    oauthServer,
    endpointOverrides,
  );
  const client = buildOauth2Client(clientId, redirectUrl, resolvedEndpoints);
  const codeChallenge = await pkceConsumer.getCodeChallenge();
  const loginUrl = await client.createAuthorizationURL({ state, scopes });

  // The OAuth2 client can do PKCE itself, but only from a verifier it hashes on its own;
  // since the challenge comes from elsewhere, it is attached by hand.
  const extraParams: [string, string][] = [
    ["code_challenge", codeChallenge],
    ["code_challenge_method", "S256"],
  ];
  if (nonce) {
    // oslo has no nonce support, so it is added manually as well
    extraParams.push(["nonce", nonce]);
  }
  // Required by the auth server for offline_access scope
  extraParams.push(["prompt", "consent"]);

  for (const [paramName, paramValue] of extraParams) {
    loginUrl.searchParams.append(paramName, paramValue);
  }
  return loginUrl;
}

/**
 * Builds the end-session (logout) URL for the configured OAuth server.
 *
 * Starts from the `endsession` endpoint and appends the client id, the ID
 * token hint, the state and the post-logout redirect URI as query parameters.
 *
 * @param config - client, token, state and redirect settings for the logout request
 * @returns the fully parameterised end-session URL
 */
export async function generateOauthLogoutUrl(config: {
  clientId: string;
  redirectUrl: string;
  idToken: string;
  state: string;
  oauthServer: string;
  endpointOverrides?: Partial<Endpoints>;
}): Promise<URL> {
  const { endsession } = await getEndpointsWithOverrides(
    config.oauthServer,
    config.endpointOverrides,
  );
  const logoutUrl = new URL(endsession);

  const logoutParams: [string, string][] = [
    ["client_id", config.clientId],
    ["id_token_hint", config.idToken],
    ["state", config.state],
    ["post_logout_redirect_uri", config.redirectUrl],
  ];
  for (const [paramName, paramValue] of logoutParams) {
    logoutUrl.searchParams.append(paramName, paramValue);
  }
  return logoutUrl;
}

/**
 * Creates an OAuth2Client wired to the given authorization and token
 * endpoints, with `redirectUri` set as its redirect URI.
 *
 * @param clientId - OAuth client identifier
 * @param redirectUri - URI the auth server redirects back to
 * @param endpoints - endpoint set; only `auth` and `token` are used here
 */
export function buildOauth2Client(
  clientId: string,
  redirectUri: string,
  endpoints: Endpoints,
): OAuth2Client {
  const { auth: authorizeEndpoint, token: tokenEndpoint } = endpoints;
  const clientOptions = { redirectURI: redirectUri };
  return new OAuth2Client(
    clientId,
    authorizeEndpoint,
    tokenEndpoint,
    clientOptions,
  );
}

/**
 * Exchanges an authorization code for OIDC tokens and validates the result.
 *
 * @param code - authorization code returned by the auth server
 * @param state - OAuth state value; not read by this function, kept for interface compatibility
 * @param pkceProducer - supplies the PKCE code verifier for the exchange
 * @param oauth2Client - client used to perform the code exchange
 * @param oauthServer - passed to validateOauth2Tokens as the expected issuer
 * @param endpoints - endpoint set; `jwks` is used for token signature validation
 * @returns the token response once it has passed validation
 * @throws Error when no code verifier is available or when token validation fails
 */
export async function exchangeTokens(
  code: string,
  state: string,
  pkceProducer: PKCEProducer,
  oauth2Client: OAuth2Client,
  oauthServer: string,
  endpoints: Endpoints,
): Promise<OIDCTokenResponseBody> {
  const codeVerifier = await pkceProducer.getCodeVerifier();
  if (!codeVerifier) throw new Error("Code verifier not found in state");

  const tokens =
    await oauth2Client.validateAuthorizationCode<OIDCTokenResponseBody>(code, {
      codeVerifier,
    });

  // Validate relevant tokens
  try {
    await validateOauth2Tokens(
      tokens,
      endpoints.jwks,
      oauth2Client,
      oauthServer,
    );
  } catch (error) {
    // Log the failure only: the token response holds live credentials
    // (ID, access and refresh tokens) and must not end up in logs.
    console.error("tokenExchange error", { error });
    // Anything can be thrown, so do not assume an Error instance
    const reason = error instanceof Error ? error.message : String(error);
    throw new Error(`OIDC tokens validation failed: ${reason}`);
  }
  return tokens;
}

/**
 * Works out the absolute expiry of the access token, in seconds since the Unix epoch.
 *
 * The `exp` claim of the access token is preferred. When the access token
 * cannot be decoded as a JWT (decodeJwt throws for opaque or malformed tokens)
 * or carries no positive numeric `exp`, the expiry is calculated as the
 * current time plus `expires_in`.
 *
 * @param tokens - token response containing `access_token` and optionally `expires_in`
 * @returns expiry time in epoch seconds
 * @throws Error when neither `exp` nor `expires_in` is available
 */
export const getAccessTokenExpiresAt = (
  tokens: OIDCTokenResponseBody,
): number => {
  let exp: unknown;
  try {
    exp = decodeJwt(tokens.access_token).exp;
  } catch {
    // Not a decodable JWT: leave exp unset so the expires_in fallback below applies
  }

  // decodeJwt does not validate claim types, so check before trusting the value
  if (typeof exp === "number" && exp > 0) {
    return exp;
  }
  if (tokens.expires_in) {
    const now = Math.floor(Date.now() / 1000);
    return now + tokens.expires_in;
  }
  throw new Error("Cannot determine access token expiry!");
};
/**
 * Persists the access token expiry under OAuthTokenTypes.ACCESS_TOKEN_EXPIRES_AT.
 *
 * The expiry is obtained from getAccessTokenExpiresAt and written as a string.
 *
 * @param storage - storage to write to
 * @param tokens - token response the expiry is derived from
 */
export async function setAccessTokenExpiresAt(
  storage: AuthStorage | CookieStorage,
  tokens: OIDCTokenResponseBody,
) {
  const expiresAt = getAccessTokenExpiresAt(tokens);
  const serializedExpiry = String(expiresAt);
  await storage.set(OAuthTokenTypes.ACCESS_TOKEN_EXPIRES_AT, serializedExpiry);
}

/**
 * Writes the ID token, the access token and (when present) the refresh token
 * to storage, then records the access token expiry.
 *
 * Writes happen one after another, in that order.
 *
 * @param storage - storage to write to
 * @param tokens - token response to persist
 */
export async function storeTokens(
  storage: AuthStorage,
  tokens: OIDCTokenResponseBody,
) {
  const {
    id_token: idToken,
    access_token: accessToken,
    refresh_token: refreshToken,
  } = tokens;

  await storage.set(OAuthTokenTypes.ID_TOKEN, idToken);
  await storage.set(OAuthTokenTypes.ACCESS_TOKEN, accessToken);
  // a refresh token is optional in the response; skip the write when absent
  if (refreshToken) {
    await storage.set(OAuthTokenTypes.REFRESH_TOKEN, refreshToken);
  }
  await setAccessTokenExpiresAt(storage, tokens);
}

/**
 * Stores the tokens through cookie storage, giving each entry a `maxAge`
 * derived from the corresponding token's expiry.
 *
 * - ID token: lifetime from its own `exp` claim, when present.
 * - Access token and its expiry marker: lifetime from getAccessTokenExpiresAt.
 * - Refresh token: access token lifetime plus a five minute margin.
 *
 * When a lifetime is unknown or zero, no `maxAge` override is passed for that entry.
 *
 * @param storage - storage to write to; treated as CookieStorage so per-entry overrides can be passed
 * @param tokens - token response to persist
 */
export async function storeServerTokens(
  storage: AuthStorage | CookieStorage,
  tokens: OIDCTokenResponseBody,
) {
  // extra lifetime granted to the refresh token on top of the access token lifetime
  const refreshTokenMarginSeconds = 5 * 60;
  // build the per-entry override, omitting maxAge when the lifetime is unknown or zero
  const maxAgeOverride = (maxAge: number | undefined): { maxAge?: number } =>
    maxAge ? { maxAge } : {};

  const accessTokenExpiresAt = getAccessTokenExpiresAt(tokens);
  // NOTE(review): assumes callers pass cookie-backed storage here — confirm against callers
  const cookieStorage = storage as CookieStorage;
  const now = Math.floor(Date.now() / 1000);

  const accessTokenMaxAge = accessTokenExpiresAt - now;
  const cookiesOverride = maxAgeOverride(accessTokenMaxAge);

  // the refresh token must be longer-lived than the access token max age to allow time for automatic refresh
  // as it's not a JWT, we derive it from the access token max age and add a margin
  const refreshTokenMaxAge =
    accessTokenMaxAge && accessTokenMaxAge + refreshTokenMarginSeconds;
  const refreshCookiesOverride = maxAgeOverride(refreshTokenMaxAge);

  const idTokenExpiry = decodeJwt(tokens.id_token)?.exp;
  const idTokenMaxAge = idTokenExpiry && idTokenExpiry - now;

  await cookieStorage.set(
    OAuthTokenTypes.ID_TOKEN,
    tokens.id_token,
    maxAgeOverride(idTokenMaxAge),
  );
  await cookieStorage.set(
    OAuthTokenTypes.ACCESS_TOKEN,
    tokens.access_token,
    cookiesOverride,
  );
  if (tokens.refresh_token) {
    await cookieStorage.set(
      OAuthTokenTypes.REFRESH_TOKEN,
      tokens.refresh_token,
      refreshCookiesOverride,
    );
  }
  // the expiry marker lives exactly as long as the access token it describes
  await cookieStorage.set(
    OAuthTokenTypes.ACCESS_TOKEN_EXPIRES_AT,
    accessTokenExpiresAt.toString(),
    cookiesOverride,
  );
}

/**
 * Removes every storage key related to OAuth and the CivicAuth SDK: all
 * OAuthTokenTypes values plus the refresh-in-progress, auto-refresh timeout
 * and logout state entries.
 *
 * All deletions are started together and awaited as a group.
 *
 * @param storage - storage to clear
 */
export async function clearTokens(storage: AuthStorage) {
  const keysToRemove = [
    ...Object.values(OAuthTokenTypes),
    REFRESH_IN_PROGRESS,
    AUTOREFRESH_TIMEOUT_NAME,
    LOGOUT_STATE,
  ];
  const removals = keysToRemove.map(async (storageKey) => {
    await storage.delete(storageKey);
  });
  await Promise.all(removals);
}

/**
 * Deletes the auth server session entry and its legacy counterpart from
 * storage, one after the other.
 *
 * @param storage - storage to delete from
 */
export async function clearAuthServerSession(storage: AuthStorage) {
  const sessionKeys = [AUTH_SERVER_SESSION, AUTH_SERVER_LEGACY_SESSION];
  for (const sessionKey of sessionKeys) {
    await storage.delete(sessionKey);
  }
}

/**
 * Clears the user session held in the given storage by delegating to
 * GenericUserSession.clear().
 *
 * @param storage - storage backing the user session
 */
export async function clearUser(storage: AuthStorage) {
  await new GenericUserSession(storage).clear();
}

/**
 * Reads the stored tokens back from storage.
 *
 * @param storage - storage to read from
 * @returns the stored tokens, or `null` when the ID token or access token is
 *   missing. `refresh_token` is `undefined` when none is stored, and
 *   `access_token_expires_at` is `undefined` when the stored value is missing
 *   or not a number.
 */
export async function retrieveTokens(
  storage: AuthStorage,
): Promise<OIDCTokenResponseBody | null> {
  const idToken = await storage.get(OAuthTokenTypes.ID_TOKEN);
  const accessToken = await storage.get(OAuthTokenTypes.ACCESS_TOKEN);
  const refreshToken = await storage.get(OAuthTokenTypes.REFRESH_TOKEN);
  const storedExpiresAt = await storage.get(
    OAuthTokenTypes.ACCESS_TOKEN_EXPIRES_AT,
  );
  if (!idToken || !accessToken) return null;

  // The expiry is stored as a string. parseInt yields NaN for an empty or
  // corrupt value; report that as "unknown" rather than handing NaN to callers,
  // where every comparison against it would silently be false.
  const parsedExpiresAt =
    storedExpiresAt == null ? Number.NaN : parseInt(storedExpiresAt, 10);

  return {
    id_token: idToken,
    access_token: accessToken,
    refresh_token: refreshToken ?? undefined,
    access_token_expires_at: Number.isNaN(parsedExpiresAt)
      ? undefined
      : parsedExpiresAt,
  };
}

/**
 * Reads the stored access token expiry and converts it with `Number(...)`.
 *
 * Because of that conversion a stored `null` becomes 0 and a non-numeric
 * string becomes NaN.
 *
 * @param storage - storage to read from
 * @returns the stored expiry as a number
 */
export async function retrieveAccessTokenExpiresAt(
  storage: AuthStorage,
): Promise<number> {
  const storedExpiresAt = await storage.get(
    OAuthTokenTypes.ACCESS_TOKEN_EXPIRES_AT,
  );
  return Number(storedExpiresAt);
}

// Module-level cache for the remote JWKS resolver used by validateOauth2Tokens.
// cachedJwksUrl records the endpoint cachedJWKS was created for, so that a call
// with a different JWKS URL rebuilds the resolver instead of reusing the old one.
// Both start out as null until the first validation call populates them.
let cachedJWKS: ReturnType<typeof jose.createRemoteJWKSet> | null = null;
let cachedJwksUrl: string | null = null;

/**
 * Verifies the ID token and the access token against the remote JWKS and
 * returns their decoded payloads.
 *
 * The ID token must match one of the issuer variations and have the OAuth
 * client id as audience; the access token is checked against the issuer
 * variations only. The refresh token is passed through untouched.
 *
 * @param tokens - token response to validate
 * @param jwksEndpoint - URL of the JWKS used for signature verification
 * @param oauth2Client - client whose `clientId` is the expected ID token audience
 * @param issuer - issuer value expanded via getIssuerVariations
 * @returns the verified payloads plus the refresh token, without undefined entries
 * @throws whatever jose.jwtVerify throws when a token fails verification
 */
export async function validateOauth2Tokens(
  tokens: OIDCTokenResponseBody,
  jwksEndpoint: string,
  oauth2Client: OAuth2Client,
  issuer: string,
): Promise<ParsedTokens> {
  // Reuse the cached JWKS resolver unless there is none yet or it was built for another URL
  if (!cachedJWKS || cachedJwksUrl !== jwksEndpoint) {
    cachedJWKS = jose.createRemoteJWKSet(new URL(jwksEndpoint));
    cachedJwksUrl = jwksEndpoint;
  }
  // Pin the resolver for the duration of this call: the module-level cache may be
  // replaced by a concurrent call (different JWKS URL) while we are awaiting below.
  const jwks = cachedJWKS;
  const acceptedIssuers = getIssuerVariations(issuer);

  // validate the ID token
  const idTokenResponse = await jose.jwtVerify<JWTPayload>(
    tokens.id_token,
    jwks,
    {
      issuer: acceptedIssuers,
      audience: oauth2Client.clientId,
    },
  );

  // validate the access token
  const accessTokenResponse = await jose.jwtVerify<JWTPayload>(
    tokens.access_token,
    jwks,
    {
      issuer: acceptedIssuers,
    },
  );

  return withoutUndefined({
    id_token: idTokenResponse.payload,
    access_token: accessTokenResponse.payload,
    refresh_token: tokens.refresh_token,
  });
}
