import { inject, Injectable } from '@angular/core';
import { DeviceKeyAlgorithm, OlmDeviceModel, OlmOneTimeKeys } from '@portal/wen-backend-api';
import { forkJoin, Observable, of } from 'rxjs';
import { defaultIfEmpty, first, map, shareReplay, switchMap } from 'rxjs/operators';
import { ChatBackendApi, ChatUserApi, GetOneTimeKeyResult } from '../backend-api/backend-api';
import { OlmDevice } from '../device/olm-device';
import { CryptoStorage } from '../persistence/crypto-storage';
import { getCurrentDateTime } from '../util/date-util';
import { findBestFittingSession } from '../util/find-best-fitting-session';
import { OlmEncryptionResult } from './crypto-results';

/**
 * Outcome of resolving an outbound session for a single target device.
 */
type SessionResolveResult = {
  // Session value taken either from a stored session model or from a freshly
  // created outbound session (presumably its serialized form - TODO confirm).
  session: string;
  // Identifier belonging to that session.
  sessionId: string;
  // Id of the device the session was resolved for.
  deviceId: string;
};

/**
 * Lazily requests the one-time keys of one user and answers per-device
 * lookups from that single, shared request.
 */
class OneTimeKeyRequest {

  constructor(
    private chatBackendApi: ChatBackendApi,
    private userId: string,
  ) { }

  // Created on the first `get` call and reused by every later call.
  private oneTimeKey$: Observable<GetOneTimeKeyResult>;

  /**
   * Looks up the one-time key entry of the given device in the user's keys.
   *
   * @param deviceId id of the device whose entry is wanted
   * @returns an observable of the entry matching both the device id and this
   *   request's user id, or `undefined` when the result holds no such entry
   */
  get(deviceId: string): Observable<OlmOneTimeKeys> {
    let keysOfUser$ = this.oneTimeKey$;
    if (!keysOfUser$) {
      const request$ = this.chatBackendApi.getOneTimeKey({ userId: this.userId });
      // Only the first emission is taken; it is replayed to later subscribers.
      keysOfUser$ = request$.pipe(first(), shareReplay(1));
      this.oneTimeKey$ = keysOfUser$;
    }
    return keysOfUser$.pipe(
      map((keysOfUser) => keysOfUser.find(
        (entry) => entry.deviceId === deviceId && entry.userId === this.userId
      ))
    );
  }

}

/**
 * Encrypts message content for the devices of one or more users.
 *
 * For every target device an outbound session is resolved first: a stored
 * session is reused when `findBestFittingSession` selects one, otherwise a new
 * outbound session is created from a one-time key requested through the
 * backend API. Devices for which no session can be resolved are left out of
 * the result.
 */
@Injectable()
export class OlmEncryptor {

  private olmDevice = inject(OlmDevice);
  private chatBackendApi = inject(ChatBackendApi);
  private chatUserApi = inject(ChatUserApi);
  private dataStore = inject(CryptoStorage);

  /**
   * Downloads the device keys of the given users and stores them.
   *
   * @param userId ids of the users whose device keys are requested
   * @returns the downloaded device keys, emitted after they have been stored
   */
  private resolveDevicesFor(userId: string[]): Observable<OlmDeviceModel[]> {
    return this.chatBackendApi.getDeviceKeys({ userIds: userId }).pipe(
      switchMap((downloadedDevices) => {
        return this.dataStore.storeDevices(downloadedDevices.deviceKeys).pipe(
          map((() => downloadedDevices.deviceKeys))
        );
      })
    );
  }

  /**
   * Resolves an outbound session for one device.
   *
   * A stored session chosen by `findBestFittingSession` is preferred. Without
   * one, a new outbound session is created from the device's one-time key and
   * persisted together with the account model.
   *
   * @param curve25519 curve25519 key of the target device
   * @param deviceId id of the target device
   * @param oneTimeKeyRequest shared one-time key request of the device's user
   * @returns the resolved session, or `null` when the device has neither a
   *   stored session nor a usable one-time key
   */
  private resolveSessionFor(curve25519: string, deviceId: string, oneTimeKeyRequest: OneTimeKeyRequest): Observable<SessionResolveResult> {
    return this.dataStore.getSessionsForDevice({ curve25519 }).pipe(
      switchMap((sessionModels) => {
        const sessionModel = findBestFittingSession(sessionModels);
        if (sessionModel) {
          const result: SessionResolveResult = {
            session: sessionModel.session,
            sessionId: sessionModel.sessionId,
            deviceId
          };
          return of(result);
        }
        return oneTimeKeyRequest.get(deviceId).pipe(
          first(),
          switchMap((getOneTimeKeyResult) => {
            if (!getOneTimeKeyResult) {
              return of(null);
            }
            const availableKeys = getOneTimeKeyResult.keys;
            const oneTimeKey = availableKeys ? Object.values(availableKeys)[0] : undefined;
            if (!oneTimeKey) {
              // An entry without any key cannot seed an outbound session. Skip
              // this device like a missing entry instead of letting the session
              // creation fail and abort the encryption for all other devices.
              return of(null);
            }
            const sessionData = this.olmDevice.createOutboundSession(curve25519, oneTimeKey);
            const storeAcc$ = this.dataStore.storeAccount(this.olmDevice.getAccountModel());
            const storeSession$ = this.dataStore.storeSession({
              curve25519,
              session: sessionData.session,
              sessionId: sessionData.sessionId,
              lastActivityTimestamp: getCurrentDateTime()
            });
            // Emit the result only after both the account and the session are stored.
            return forkJoin([storeAcc$, storeSession$]).pipe(
              map(() => {
                const result: SessionResolveResult = {
                  session: sessionData.session,
                  sessionId: sessionData.sessionId,
                  deviceId
                };
                return result;
              })
            );
          })
        );
      }),
    );
  }

  /**
   * Resolves outbound sessions for all relevant devices of one user.
   *
   * A device is relevant when it uses the curve25519 algorithm, belongs to
   * `userId`, is not the current user's own device and, if `targetDeviceId`
   * is set, has exactly that id.
   *
   * @param userId id of the user whose devices are targeted
   * @param devices candidate devices, possibly belonging to several users
   * @param targetDeviceId optional id restricting the result to one device
   * @returns the resolved sessions; devices without a session are omitted
   */
  private ensureOutboundSessions(userId: string, devices: OlmDeviceModel[], targetDeviceId: string): Observable<SessionResolveResult[]> {
    // One shared request per user, so all of the user's devices reuse the same key download.
    const oneTimeKeyRequest = new OneTimeKeyRequest(this.chatBackendApi, userId);
    const allSessionsForAllDevices$ = this.chatUserApi.getUserData().pipe(
      map((userData) => {
        const targetDevices = devices
          .filter(device => {
            const isCurvekey = device.algorithm === DeviceKeyAlgorithm.curve25519;
            const isTargetDevice = targetDeviceId ? device.deviceId === targetDeviceId : true;
            const isThisDevice = userData.deviceId === device.deviceId;
            const isTargetUser = device.userId === userId;
            return isCurvekey && isTargetDevice && isTargetUser && !isThisDevice;
          });

        return targetDevices;
      }),
      switchMap((targetDevices) => {
        if (!targetDevices.length) {
          // forkJoin over an empty list completes without emitting, so emit explicitly.
          return of([]);
        }
        const sessions$ = targetDevices.map(targetDevice => {
          return this.resolveSessionFor(targetDevice.key, targetDevice.deviceId, oneTimeKeyRequest);
        });
        return forkJoin(sessions$).pipe(
          // Drop the `null` placeholders of devices without a resolvable session.
          map((sessions => sessions.filter(session => Boolean(session))))
        );
      }),
      shareReplay(1)
    );
    return allSessionsForAllDevices$;
  }

  /**
   * Encrypts `content` once per resolved device session of the target users.
   *
   * @param targetUserId a single user id or a list of user ids
   * @param content the content to encrypt
   * @param targetDeviceId optional device id; when set only devices with this id are targeted
   * @returns one encryption result per device session, flattened over all
   *   users; an empty array when there is nothing to encrypt for
   */
  encryptMessage(targetUserId: string | string[], content: string, targetDeviceId?: string): Observable<OlmEncryptionResult[]> {
    const userIds = Array.isArray(targetUserId) ? targetUserId : [targetUserId];
    const senderCurve25519 = this.olmDevice.getIdentityKeys().curve25519;
    return this.resolveDevicesFor(userIds).pipe(
      switchMap((devices) => {
        const results$ = userIds.map((userId) => {
          return this.ensureOutboundSessions(userId, devices, targetDeviceId).pipe(
            map((sessionsForUser) => {
              const encryptionResults = sessionsForUser.map(sessionForUser => {
                const { encrypted, usedSession } = this.olmDevice.encryptMessage(sessionForUser.session, content);
                const encryptionResult: OlmEncryptionResult = {
                  encrypted,
                  usedSession,
                  targetDeviceId: sessionForUser.deviceId,
                  targetUserId: userId,
                  senderCurve25519
                };
                return encryptionResult;
              });
              return encryptionResults;
            })
          );
        });
        return forkJoin(results$).pipe(
          // Flatten the per-user result lists into a single list.
          map((results) => Array.prototype.concat.apply([], results)),
          // Covers an empty user list, where forkJoin completes without emitting.
          defaultIfEmpty([]),
        );
      })
    );
  }

}
