import { GenericPublicClientPKCEProducer } from "@/services/PKCE.js";
import { OAuth2Client } from "oslo/oauth2";
import type {
  AuthStorage,
  Endpoints,
  OIDCTokenResponseBody,
  SessionData,
} from "@/types.js";
import type { AuthConfig } from "@/server/config.js";
import {
  exchangeTokens,
  getEndpointsWithOverrides,
  retrieveTokens,
  storeServerTokens,
  validateOauth2Tokens,
} from "@/shared/lib/util.js";
import type { AuthenticationResolver, PKCEProducer } from "@/services/types.ts";
import { DEFAULT_AUTH_SERVER } from "@/constants.js";

/**
 * Server-side implementation of {@link AuthenticationResolver}.
 *
 * Resolves the OAuth endpoints for the configured auth server, exchanges
 * authorization codes for tokens using PKCE, and reads/validates the session
 * tokens held in the supplied {@link AuthStorage}.
 *
 * The constructor is private: instances are obtained through
 * {@link ServerAuthenticationResolver.build}, which awaits {@link init}
 * before handing the resolver out.
 */
export class ServerAuthenticationResolver implements AuthenticationResolver {
  private pkceProducer: PKCEProducer;
  // Both fields are populated by init(); they stay undefined until then.
  private oauth2client: OAuth2Client | undefined;
  private endpoints: Endpoints | undefined;

  /**
   * @param authConfig - Client configuration (client id, redirect URL, optional auth server).
   * @param storage - Storage used for the session tokens and handed to the PKCE producer.
   * @param endpointOverrides - Optional endpoint values passed to the endpoint resolution in {@link init}.
   */
  private constructor(
    readonly authConfig: AuthConfig,
    readonly storage: AuthStorage,
    readonly endpointOverrides?: Partial<Endpoints>,
  ) {
    this.pkceProducer = new GenericPublicClientPKCEProducer(storage);
  }

  /**
   * Returns the initialised endpoints and OAuth2 client with their types
   * narrowed, so callers do not need non-null assertions.
   *
   * @throws Error if {@link init} has not populated both fields.
   */
  private requireInitialized(): {
    endpoints: Endpoints;
    oauth2client: OAuth2Client;
  } {
    const { endpoints, oauth2client } = this;
    if (!endpoints || !oauth2client) {
      throw new Error("ServerAuthenticationResolver is not initialised");
    }
    return { endpoints, oauth2client };
  }

  /**
   * Validates the tokens currently held in storage.
   *
   * @returns The stored session data when the tokens validate; otherwise the
   * stored data (if any) with `authenticated` forced to `false`.
   * @throws Error if no JWKS endpoint is available after initialisation.
   */
  async validateExistingSession(): Promise<SessionData> {
    // TODO: investigate a more performant way to validate a server session
    // other than using JWKS and JWT verification which is what validateOauth2Tokens uses
    const sessionData = await this.getSessionData();
    if (!sessionData?.idToken || !sessionData.accessToken) {
      // Spreading a null session yields just `{ authenticated: false }`.
      return { ...sessionData, authenticated: false };
    }

    // Lazily (re-)initialise when the JWKS endpoint or the client is missing.
    if (!this.endpoints?.jwks || !this.oauth2client) await this.init();

    const { endpoints, oauth2client } = this.requireInitialized();
    if (!endpoints.jwks) {
      throw new Error("JWKS endpoint not found");
    }

    try {
      // this function will throw if any of the tokens are invalid
      await validateOauth2Tokens(
        {
          access_token: sessionData.accessToken,
          id_token: sessionData.idToken,
          refresh_token: sessionData.refreshToken,
          // NOTE(review): getSessionData() never sets accessTokenExpiresAt, so
          // this is always undefined here — confirm whether the expiry should
          // be read from storage as well.
          access_token_expires_at: sessionData.accessTokenExpiresAt,
        },
        endpoints.jwks,
        oauth2client,
        this.oauthServer,
      );
      return sessionData;
    } catch (error) {
      // A validation failure is reported as an unauthenticated session
      // rather than propagated to the caller.
      console.error("Error validating tokens", error);
      return { ...sessionData, authenticated: false };
    }
  }

  /**
   * The auth server to use: the configured `oauthServer`, or
   * `DEFAULT_AUTH_SERVER` when the configured value is missing or empty.
   */
  get oauthServer(): string {
    return this.authConfig.oauthServer || DEFAULT_AUTH_SERVER;
  }

  /**
   * Resolves the endpoints (applying any overrides) and creates the OAuth2
   * client from them. Safe to call again; it replaces both fields.
   *
   * @returns This resolver, to allow chaining.
   */
  async init(): Promise<this> {
    // resolve oauth config
    this.endpoints = await getEndpointsWithOverrides(
      this.oauthServer,
      this.endpointOverrides,
    );
    this.oauth2client = new OAuth2Client(
      this.authConfig.clientId,
      this.endpoints.auth,
      this.endpoints.token,
      {
        redirectURI: this.authConfig.redirectUrl,
      },
    );

    return this;
  }

  /**
   * Exchanges an authorization code for tokens and passes the result to
   * `storeServerTokens` together with this resolver's storage.
   *
   * @param code - Authorization code received on the redirect.
   * @param state - State value received on the redirect.
   * @returns The token response body returned by the exchange.
   * @throws Error if the PKCE producer has no code verifier.
   */
  async tokenExchange(
    code: string,
    state: string,
  ): Promise<OIDCTokenResponseBody> {
    // Both the client and the endpoints are needed below, so check both.
    if (!this.oauth2client || !this.endpoints) await this.init();
    const { endpoints, oauth2client } = this.requireInitialized();

    // Fail early with a clear message when no verifier is available.
    const codeVerifier = await this.pkceProducer.getCodeVerifier();
    if (!codeVerifier) throw new Error("Code verifier not found in storage");

    // exchange auth code for tokens
    const tokens = await exchangeTokens(
      code,
      state,
      this.pkceProducer,
      oauth2client,
      this.oauthServer,
      endpoints,
    );

    await storeServerTokens(this.storage, tokens);

    return tokens;
  }

  /**
   * Reads the tokens from storage and maps them to {@link SessionData}.
   *
   * @returns `null` when `retrieveTokens` yields nothing; otherwise session
   * data whose `authenticated` flag reflects whether an ID token is present.
   */
  async getSessionData(): Promise<SessionData | null> {
    const storageData = await retrieveTokens(this.storage);

    if (!storageData) return null;

    return {
      authenticated: !!storageData.id_token,
      idToken: storageData.id_token,
      accessToken: storageData.access_token,
      refreshToken: storageData.refresh_token,
    };
  }

  /**
   * @returns The end-session endpoint, or `null` when the endpoints have not
   * been resolved yet. Does not trigger initialisation.
   */
  async getEndSessionEndpoint(): Promise<string | null> {
    if (!this.endpoints) {
      return null;
    }
    return this.endpoints.endsession;
  }

  /**
   * Creates a resolver and awaits its initialisation.
   *
   * @param authConfig - Client configuration.
   * @param storage - Storage used for the session tokens and PKCE data.
   * @param endpointOverrides - Optional endpoint values passed to endpoint resolution.
   * @returns The initialised resolver.
   */
  static async build(
    authConfig: AuthConfig,
    storage: AuthStorage,
    endpointOverrides?: Partial<Endpoints>,
  ): Promise<AuthenticationResolver> {
    const resolver = new ServerAuthenticationResolver(
      authConfig,
      storage,
      endpointOverrides,
    );
    await resolver.init();

    return resolver;
  }
}
