// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import {
  CloudWatchLogsClient,
  CreateLogGroupCommand,
  CreateLogStreamCommand,
  DescribeLogGroupsCommand,
  DescribeLogStreamsCommand,
  GetLogEventsCommand,
  PutLogEventsCommand,
} from "@aws-sdk/client-cloudwatch-logs";
import { fetchAuthSession } from "aws-amplify/auth";
import { ConsoleLogger } from "aws-amplify/utils";

import type { LoggingProvider } from "@aws-amplify/core/dist/esm/Logger/types";
import type { AWSCredentials } from "@aws-amplify/core/internals/utils";
import type {
  CreateLogGroupCommandInput,
  CreateLogGroupCommandOutput,
  CreateLogStreamCommandInput,
  CreateLogStreamCommandOutput,
  DescribeLogGroupsCommandInput,
  DescribeLogGroupsCommandOutput,
  DescribeLogStreamsCommandInput,
  DescribeLogStreamsCommandOutput,
  GetLogEventsCommandInput,
  GetLogEventsCommandOutput,
  InputLogEvent,
  LogGroup,
  LogStream,
  PutLogEventsCommandInput,
  PutLogEventsCommandOutput,
} from "@aws-sdk/client-cloudwatch-logs";

// Limits, identifiers and error names shared by the CloudWatch logging provider.

/** Bytes added to each event's UTF-8 message length when sizing uploads. */
const AWS_CLOUDWATCH_BASE_BUFFER_SIZE = 26;
/** Largest total byte size allowed for one batch of events (1 MiB). */
const AWS_CLOUDWATCH_MAX_BATCH_EVENT_SIZE = 1024 * 1024;
/** Largest byte size (message + per-event overhead) accepted for one event. */
const AWS_CLOUDWATCH_MAX_EVENT_SIZE = 256 * 1000;
/** Category reported by the provider. */
const AWS_CLOUDWATCH_CATEGORY = "Logging";
/** Name reported by the provider. */
const AWS_CLOUDWATCH_PROVIDER_NAME = "AWSCloudWatch";
/** Message of the Error thrown when no credentials could be obtained. */
const NO_CREDS_ERROR_STRING = "No credentials";
/** Error names after which an upload is retried with a fresh sequence token. */
const RETRY_ERROR_CODES: string[] = [
  "ResourceNotFoundException",
  "InvalidSequenceTokenException",
];

/** Options accepted by the CloudWatch logging provider and its `configure` method. */
interface AWSCloudWatchProviderOptions {
  /** Name of the log group that receives the events. */
  logGroupName?: string;
  /** Name of the log stream (inside `logGroupName`) that receives the events. */
  logStreamName?: string;
  /** Custom service endpoint handed to the CloudWatch Logs client. */
  endpoint?: string;
  /** AWS region identifier, e.g. `eu-west-2`. */
  region?: string;
  /** Credentials handed to the CloudWatch Logs client. */
  credentials?: AWSCredentials;
}

/** Mutable upload bookkeeping owned by a provider instance. */
interface CloudWatchDataTracker {
  /** Events queued by `pushLogs` that have not yet been moved into a batch. */
  logEvents: InputLogEvent[];
  /** Set while an upload is running; the upload timer skips its tick while true. */
  eventUploadInProgress: boolean;
  /** Log group already confirmed to exist, cached to avoid repeated lookups. */
  verifiedLogGroup?: LogGroup;
}

// Module-wide logger, named after the provider ("AWSCloudWatch").
const logger = new ConsoleLogger(AWS_CLOUDWATCH_PROVIDER_NAME);

/**
 * Amplify logging provider that queues log events in memory and uploads them
 * to an Amazon CloudWatch Logs log stream with `PutLogEvents`.
 *
 * A timer started by the constructor attempts an upload every 2 seconds. The
 * configured log group and log stream are created on demand when missing.
 */
class AWSCloudWatchProvider implements LoggingProvider {
  static readonly PROVIDER_NAME = AWS_CLOUDWATCH_PROVIDER_NAME;
  static readonly CATEGORY = AWS_CLOUDWATCH_CATEGORY;

  /** Region used when `configure()` did not supply one. */
  private static readonly DEFAULT_REGION = "eu-west-2";

  /** PutLogEvents accepts at most this many events per request. */
  private static readonly MAX_EVENTS_PER_BATCH = 10000;

  /** Delay, in milliseconds, between upload attempts made by the timer. */
  private static readonly PUSH_INTERVAL_MS = 2000;

  private _config: AWSCloudWatchProviderOptions = {};
  private _dataTracker: CloudWatchDataTracker;
  /** Batch currently being sent; kept until an upload of it succeeds. */
  private _currentLogBatch: InputLogEvent[];
  private _timer: ReturnType<typeof setInterval> | undefined;
  private _nextSequenceToken: string | undefined;

  /**
   * @param config - Initial options, merged via `configure`.
   */
  constructor(config?: AWSCloudWatchProviderOptions) {
    this.configure(config);
    this._dataTracker = {
      eventUploadInProgress: false,
      logEvents: [],
    };
    this._currentLogBatch = [];
    this._initiateLogPushInterval();
  }

  public getProviderName(): string {
    return AWSCloudWatchProvider.PROVIDER_NAME;
  }

  public getCategoryName(): string {
    return AWSCloudWatchProvider.CATEGORY;
  }

  /** Returns the live queue of events that have not yet been batched. */
  public getLogQueue(): InputLogEvent[] {
    return this._dataTracker.logEvents;
  }

  /**
   * Merges `config` into the current configuration.
   *
   * @param config - Options to apply; keys present here override earlier values.
   *   When omitted, the configuration is left unchanged.
   * @returns The resulting configuration.
   */
  public configure(
    config?: AWSCloudWatchProviderOptions,
  ): AWSCloudWatchProviderOptions {
    if (!config) return this._config;
    // Store a merged copy so a later partial configure() only overrides the
    // keys it specifies; every client/credential lookup reads `_config`.
    this._config = { ...this._config, ...config };
    return this._config;
  }

  /**
   * Creates a CloudWatch log group.
   *
   * @throws Error(NO_CREDS_ERROR_STRING) when no credentials are available,
   *   or any error raised by the SDK call.
   */
  public async createLogGroup(
    params: CreateLogGroupCommandInput,
  ): Promise<CreateLogGroupCommandOutput> {
    logger.debug(
      "creating new log group in CloudWatch - ",
      params.logGroupName,
    );
    const cmd = new CreateLogGroupCommand(params);

    try {
      const client = await this._getAuthorizedClient();
      return await client.send(cmd);
    } catch (error) {
      logger.error(`error creating log group - ${error}`);
      throw error;
    }
  }

  /**
   * Lists log groups (DescribeLogGroups).
   *
   * @throws Error(NO_CREDS_ERROR_STRING) when no credentials are available,
   *   or any error raised by the SDK call.
   */
  public async getLogGroups(
    params: DescribeLogGroupsCommandInput,
  ): Promise<DescribeLogGroupsCommandOutput> {
    logger.debug("getting list of log groups");
    const cmd = new DescribeLogGroupsCommand(params);

    try {
      const client = await this._getAuthorizedClient();
      return await client.send(cmd);
    } catch (error) {
      logger.error(`error getting log group - ${error}`);
      throw error;
    }
  }

  /**
   * Creates a log stream inside an existing log group.
   *
   * @throws Error(NO_CREDS_ERROR_STRING) when no credentials are available,
   *   or any error raised by the SDK call.
   */
  public async createLogStream(
    params: CreateLogStreamCommandInput,
  ): Promise<CreateLogStreamCommandOutput> {
    logger.debug(
      "creating new log stream in CloudWatch - ",
      params.logStreamName,
    );
    const cmd = new CreateLogStreamCommand(params);

    try {
      const client = await this._getAuthorizedClient();
      return await client.send(cmd);
    } catch (error) {
      logger.error(`error creating log stream - ${error}`);
      throw error;
    }
  }

  /**
   * Lists log streams of a log group (DescribeLogStreams).
   *
   * @throws Error(NO_CREDS_ERROR_STRING) when no credentials are available,
   *   or any error raised by the SDK call.
   */
  public async getLogStreams(
    params: DescribeLogStreamsCommandInput,
  ): Promise<DescribeLogStreamsCommandOutput> {
    logger.debug("getting list of log streams");
    const cmd = new DescribeLogStreamsCommand(params);

    try {
      const client = await this._getAuthorizedClient();
      return await client.send(cmd);
    } catch (error) {
      logger.error(`error getting log stream - ${error}`);
      throw error;
    }
  }

  /**
   * Reads events from a log stream (GetLogEvents).
   *
   * @throws Error(NO_CREDS_ERROR_STRING) when no credentials are available,
   *   or any error raised by the SDK call.
   */
  public async getLogEvents(
    params: GetLogEventsCommandInput,
  ): Promise<GetLogEventsCommandOutput> {
    logger.debug("getting log events from stream - ", params.logStreamName);
    const cmd = new GetLogEventsCommand(params);

    try {
      const client = await this._getAuthorizedClient();
      return await client.send(cmd);
    } catch (error) {
      logger.error(`error getting log events - ${error}`);
      throw error;
    }
  }

  /**
   * Queues events for the next timer-driven upload. Nothing is sent here.
   *
   * @param logs - Events to append to the queue, in order.
   */
  public pushLogs(logs: InputLogEvent[]): void {
    logger.debug("pushing log events to Cloudwatch...");
    // Append in place instead of copying the whole queue on every call.
    for (const log of logs) {
      this._dataTracker.logEvents.push(log);
    }
  }

  /**
   * Resolves credentials and builds a client that uses them.
   *
   * @throws Error(NO_CREDS_ERROR_STRING) when no credentials are available.
   */
  private async _getAuthorizedClient(): Promise<CloudWatchLogsClient> {
    const credentialsOK = await this._ensureCredentials();
    if (!credentialsOK) {
      throw new Error(NO_CREDS_ERROR_STRING);
    }
    return this._initCloudWatchLogs();
  }

  /**
   * Returns the log group named `logGroupName`, creating it when absent.
   *
   * @returns The existing group (also cached), or `null` when it had to be created.
   */
  private async _validateLogGroupExistsAndCreate(
    logGroupName: string,
  ): Promise<LogGroup | null> {
    if (this._dataTracker.verifiedLogGroup) {
      return this._dataTracker.verifiedLogGroup;
    }

    try {
      // getLogGroups resolves credentials itself and throws when there are none.
      const currGroups = await this.getLogGroups({
        logGroupNamePrefix: logGroupName,
      });

      // The prefix search can return other groups; require an exact name match.
      const found = currGroups.logGroups?.find(
        (group) => group.logGroupName === logGroupName,
      );
      if (found) {
        this._dataTracker.verifiedLogGroup = found;
        return found;
      }

      await this.createLogGroup({ logGroupName });
      return null;
    } catch (err) {
      logger.error(`failure during log group search: ${err}`);
      throw err;
    }
  }

  /**
   * Returns the log stream `logStreamName` in `logGroupName`, creating it when
   * absent. When found, its upload sequence token is cached.
   *
   * @returns The existing stream, or `null` when it had to be created.
   */
  private async _validateLogStreamExists(
    logGroupName: string,
    logStreamName: string,
  ): Promise<LogStream | null> {
    try {
      // getLogStreams resolves credentials itself and throws when there are none.
      const currStreams = await this.getLogStreams({
        logGroupName,
        logStreamNamePrefix: logStreamName,
      });

      const found = currStreams.logStreams?.find(
        (stream) => stream.logStreamName === logStreamName,
      );
      if (found) {
        this._nextSequenceToken = found.uploadSequenceToken;
        return found;
      }

      await this.createLogStream({
        logGroupName,
        logStreamName,
      });
      return null;
    } catch (err) {
      logger.error(`failure during log stream search: ${err}`);
      throw err;
    }
  }

  /**
   * Sends one PutLogEvents request.
   *
   * @throws Error(NO_CREDS_ERROR_STRING) when no credentials are available,
   *   or the SDK error. The error is rethrown so callers can inspect `err.name`
   *   (retryable codes) and keep the batch for a later attempt.
   */
  private async _sendLogEvents(
    params: PutLogEventsCommandInput,
  ): Promise<PutLogEventsCommandOutput> {
    try {
      const client = await this._getAuthorizedClient();
      logger.debug("sending log events to stream - ", params.logStreamName);
      return await client.send(new PutLogEventsCommand(params));
    } catch (err) {
      logger.error(`failure during log push: ${err}`);
      throw err;
    }
  }

  /** Builds a CloudWatch Logs client from the current configuration. */
  private _initCloudWatchLogs(): CloudWatchLogsClient {
    return new CloudWatchLogsClient({
      region: this._config.region ?? AWSCloudWatchProvider.DEFAULT_REGION,
      credentials: this._config.credentials,
      customUserAgent: "aws-amplify/5.0.4",
      endpoint: this._config.endpoint,
    });
  }

  /**
   * Refreshes credentials from the Amplify auth session.
   *
   * @returns `true` when credentials are available for signing requests.
   */
  private async _ensureCredentials(): Promise<boolean> {
    try {
      const authSession = await fetchAuthSession();
      if (!authSession) return false;
      // Only replace credentials when the session actually carries some, so
      // credentials supplied through configure() are not clobbered by undefined.
      if (authSession.credentials) {
        this._config.credentials = authSession.credentials;
      }
      return this._config.credentials !== undefined;
    } catch (error) {
      logger.warn("ensure credentials error", error);
      return false;
    }
  }

  /**
   * Returns the cached sequence token or, when none is cached, makes sure the
   * configured log group and stream exist and reads the stream's token.
   *
   * A token will not exist if the group or stream did not exist, or the stream
   * exists but has no logs written to it yet; `undefined` is returned then.
   */
  private async _getNextSequenceToken(): Promise<string | undefined> {
    if (this._nextSequenceToken) {
      return this._nextSequenceToken;
    }

    const { logGroupName, logStreamName } = this._config;
    try {
      if (!logGroupName || !logStreamName) {
        throw new Error("logGroupName and logStreamName must be configured");
      }

      await this._validateLogGroupExistsAndCreate(logGroupName);

      this._nextSequenceToken = undefined;
      const logStream = await this._validateLogStreamExists(
        logGroupName,
        logStreamName,
      );
      this._nextSequenceToken = logStream?.uploadSequenceToken;

      return this._nextSequenceToken;
    } catch (err) {
      logger.error(`failure while getting next sequence token: ${err}`);
      throw err;
    }
  }

  /**
   * Uploads the pending batch, or a new batch taken from the queue.
   *
   * On a retryable error (see RETRY_ERROR_CODES) cached group/token state is
   * discarded and the batch is re-sent once. Otherwise, the error is rethrown
   * and the batch stays in `_currentLogBatch` for the next timer tick.
   */
  private async _safeUploadLogEvents(): Promise<
    PutLogEventsCommandOutput | undefined
  > {
    // Claim the upload slot before the first await so an overlapping timer
    // tick cannot start a second upload concurrently.
    this._dataTracker.eventUploadInProgress = true;
    try {
      const seqToken = await this._getNextSequenceToken();
      const logBatch =
        this._currentLogBatch.length === 0
          ? this._getBufferedBatchOfLogs()
          : this._currentLogBatch;
      if (logBatch.length === 0) return undefined;

      const response = await this._sendLogEvents({
        logGroupName: this._config.logGroupName,
        logStreamName: this._config.logStreamName,
        logEvents: logBatch,
        sequenceToken: seqToken,
      });

      this._nextSequenceToken = response.nextSequenceToken;
      this._currentLogBatch = [];
      return response;
    } catch (err) {
      logger.error(`error during _safeUploadLogEvents: ${err}`);

      const errName = err instanceof Error ? err.name : undefined;
      if (errName !== undefined && RETRY_ERROR_CODES.includes(errName)) {
        // The group/stream may have been deleted or the token gone stale.
        this._dataTracker.verifiedLogGroup = undefined;
        this._nextSequenceToken = undefined;
        if (this._currentLogBatch.length > 0) {
          return await this._getNewSequenceTokenAndSubmit({
            logEvents: this._currentLogBatch,
            logGroupName: this._config.logGroupName,
            logStreamName: this._config.logStreamName,
          });
        }
      }
      throw err;
    } finally {
      this._dataTracker.eventUploadInProgress = false;
    }
  }

  /**
   * Moves the largest acceptable prefix of the queue into `_currentLogBatch`.
   *
   * CloudWatch restricts both the size of each event and the batch as a whole,
   * and requires the events of a batch to be in chronological order:
   * https://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/API_PutLogEvents.html
   * Oversized messages are truncated in place to fit the per-event limit.
   */
  private _getBufferedBatchOfLogs(): InputLogEvent[] {
    const encoder = new TextEncoder();
    const maxMessageBytes =
      AWS_CLOUDWATCH_MAX_EVENT_SIZE - AWS_CLOUDWATCH_BASE_BUFFER_SIZE;
    const queue = this._dataTracker.logEvents;
    const maxEvents = Math.min(
      queue.length,
      AWSCloudWatchProvider.MAX_EVENTS_PER_BATCH,
    );

    let currentEventIdx = 0;
    let totalByteSize = 0;

    while (currentEventIdx < maxEvents) {
      const currentEvent = queue[currentEventIdx];
      let eventSize = 0;
      if (currentEvent) {
        const message = currentEvent.message ?? "";
        eventSize =
          encoder.encode(message).length + AWS_CLOUDWATCH_BASE_BUFFER_SIZE;
        if (eventSize > AWS_CLOUDWATCH_MAX_EVENT_SIZE) {
          const errString = `Log entry exceeds maximum size for CloudWatch logs. Log size: ${eventSize}. Truncating log message.`;
          logger.warn(errString);

          // Cut by UTF-8 byte length (not character count) so that message
          // plus per-event overhead fits the per-event limit.
          currentEvent.message = AWSCloudWatchProvider._truncateUtf8(
            message,
            maxMessageBytes,
          );
          eventSize =
            encoder.encode(currentEvent.message).length +
            AWS_CLOUDWATCH_BASE_BUFFER_SIZE;
        }
      }

      if (totalByteSize + eventSize > AWS_CLOUDWATCH_MAX_BATCH_EVENT_SIZE)
        break;
      totalByteSize += eventSize;
      currentEventIdx++;
    }

    const batch = queue.splice(0, currentEventIdx);
    // PutLogEvents rejects batches that are not in chronological order;
    // Array.prototype.sort is stable, so equal timestamps keep queue order.
    batch.sort((a, b) => (a.timestamp ?? 0) - (b.timestamp ?? 0));
    this._currentLogBatch = batch;

    return batch;
  }

  /**
   * Returns the longest prefix of `text` whose UTF-8 encoding is at most
   * `maxBytes` bytes, never splitting a surrogate pair.
   */
  private static _truncateUtf8(text: string, maxBytes: number): string {
    let bytes = 0;
    let end = 0;
    while (end < text.length) {
      const code = text.charCodeAt(end);
      let units = 1;
      let width: number;
      if (code < 0x80) {
        width = 1;
      } else if (code < 0x800) {
        width = 2;
      } else if (code >= 0xd800 && code <= 0xdbff && end + 1 < text.length) {
        const next = text.charCodeAt(end + 1);
        if (next >= 0xdc00 && next <= 0xdfff) {
          width = 4;
          units = 2;
        } else {
          width = 3;
        }
      } else {
        // Other BMP characters; a lone surrogate encodes as U+FFFD (3 bytes).
        width = 3;
      }
      if (bytes + width > maxBytes) break;
      bytes += width;
      end += units;
    }
    return text.slice(0, end);
  }

  /**
   * Re-fetches the sequence token (re-validating group and stream) and resends
   * `payload`. `payload` itself is not modified.
   */
  private async _getNewSequenceTokenAndSubmit(
    payload: PutLogEventsCommandInput,
  ): Promise<PutLogEventsCommandOutput> {
    try {
      this._nextSequenceToken = undefined;
      // Force a fresh existence check so a deleted log group is re-created.
      this._dataTracker.verifiedLogGroup = undefined;
      this._dataTracker.eventUploadInProgress = true;

      const seqToken = await this._getNextSequenceToken();
      const response = await this._sendLogEvents({
        ...payload,
        sequenceToken: seqToken,
      });

      // Keep the token from this response; otherwise the next upload would
      // reuse the one just consumed.
      this._nextSequenceToken = response.nextSequenceToken;
      this._dataTracker.eventUploadInProgress = false;
      this._currentLogBatch = [];

      return response;
    } catch (err) {
      logger.error(
        `error when retrying log submission with new sequence token: ${err}`,
      );
      this._dataTracker.eventUploadInProgress = false;

      throw err;
    }
  }

  /** (Re)starts the timer that periodically attempts an upload. */
  private _initiateLogPushInterval(): void {
    if (this._timer) {
      clearInterval(this._timer);
    }

    this._timer = setInterval(async () => {
      try {
        if (this._getDocUploadPermissibility()) {
          await this._safeUploadLogEvents();
        }
      } catch (err) {
        logger.error(
          `error when calling _safeUploadLogEvents in the timer interval - ${err}`,
        );
      }
    }, AWSCloudWatchProvider.PUSH_INTERVAL_MS);
  }

  /** True when there is something to send and no upload is in flight. */
  private _getDocUploadPermissibility(): boolean {
    return (
      (this._dataTracker.logEvents.length !== 0 ||
        this._currentLogBatch.length !== 0) &&
      !this._dataTracker.eventUploadInProgress
    );
  }
}

// Public surface of this module: the CloudWatch logging provider class.
export { AWSCloudWatchProvider };
