#include "GamebaseVerifySignatureHelper.h"

#include "Windows/AllowWindowsPlatformTypes.h"
#include <wintrust.h>
#include <softpub.h>
#pragma comment(lib, "wintrust.lib")
#include "Windows/HideWindowsPlatformTypes.h"

#include "Misc/Paths.h"
#include "HAL/PlatformFilemanager.h"
#include "Misc/FeedbackContext.h"

namespace GamebaseVerifySignatureHelper
{
    /**
     * Scope guard for a certificate context.
     *
     * Calls CertFreeCertificateContext on the wrapped context when the guard is destroyed.
     * A null context is accepted and ignored. The guard is non-copyable: a copy would free
     * the same context a second time.
     */
    class FScopedCertContext
    {
    public:
        /** @param InContext Context to release on destruction; may be null. */
        explicit FScopedCertContext(const PCCERT_CONTEXT InContext) : Context(InContext) {}
        ~FScopedCertContext()
        {
            if (Context) CertFreeCertificateContext(Context);
        }

        FScopedCertContext(const FScopedCertContext&) = delete;
        FScopedCertContext& operator=(const FScopedCertContext&) = delete;

    private:
        /** Context released in the destructor. */
        PCCERT_CONTEXT Context;
    };

    /**
     * Scope guard for a certificate store handle.
     *
     * Calls CertCloseStore(Store, 0) on the wrapped handle when the guard is destroyed.
     * A null handle is accepted and ignored. The guard is non-copyable: a copy would close
     * the same store a second time.
     */
    class FScopedCertStore
    {
    public:
        /** @param InStore Store handle to close on destruction; may be null. */
        explicit FScopedCertStore(const HCERTSTORE InStore) : Store(InStore) {}
        ~FScopedCertStore()
        {
            if (Store) CertCloseStore(Store, 0);
        }

        FScopedCertStore(const FScopedCertStore&) = delete;
        FScopedCertStore& operator=(const FScopedCertStore&) = delete;

    private:
        /** Store handle closed in the destructor. */
        HCERTSTORE Store;
    };

    /**
     * Scope guard for a cryptographic message handle.
     *
     * Calls CryptMsgClose on the wrapped handle when the guard is destroyed.
     * A null handle is accepted and ignored. The guard is non-copyable: a copy would close
     * the same message a second time.
     */
    class FScopedCryptMsg
    {
    public:
        /** @param InMsg Message handle to close on destruction; may be null. */
        explicit FScopedCryptMsg(const HCRYPTMSG InMsg) : Msg(InMsg) {}
        ~FScopedCryptMsg()
        {
            if (Msg) CryptMsgClose(Msg);
        }

        FScopedCryptMsg(const FScopedCryptMsg&) = delete;
        FScopedCryptMsg& operator=(const FScopedCryptMsg&) = delete;

    private:
        /** Message handle closed in the destructor. */
        HCRYPTMSG Msg;
    };
}

namespace GamebaseVerifySignatureHelper
{
    namespace
    {
        /**
         * Runs WinVerifyTrust with the generic Authenticode policy against the file at FullPath.
         *
         * @param FullPath Absolute path of the file to verify.
         * @return Unset on success; otherwise an error message containing the WinVerifyTrust status.
         */
        TOptional<FString> CheckAuthenticodeTrust(const FString& FullPath)
        {
            WINTRUST_FILE_INFO FileData = {};
            FileData.cbStruct = sizeof(WINTRUST_FILE_INFO);
            FileData.pcwszFilePath = *FullPath;
            FileData.hFile = nullptr;
            FileData.pgKnownSubject = nullptr;

            WINTRUST_DATA WinTrustData = {};
            WinTrustData.cbStruct = sizeof(WINTRUST_DATA);
            WinTrustData.dwUIChoice = WTD_UI_NONE;
            // NOTE(review): WTD_REVOKE_NONE appears to disable revocation checking - confirm this is intended.
            WinTrustData.fdwRevocationChecks = WTD_REVOKE_NONE;
            WinTrustData.dwUnionChoice = WTD_CHOICE_FILE;
            WinTrustData.pFile = &FileData;
            WinTrustData.dwStateAction = WTD_STATEACTION_VERIFY;
            WinTrustData.dwProvFlags = WTD_SAFER_FLAG;
            WinTrustData.hWVTStateData = nullptr;

            GUID WvtPolicyGUID = WINTRUST_ACTION_GENERIC_VERIFY_V2;

            const LONG Status = WinVerifyTrust(nullptr, &WvtPolicyGUID, &WinTrustData);

            // The VERIFY action may allocate provider state; release it before inspecting the
            // result so that it is freed on both the success and the failure path.
            if (WinTrustData.hWVTStateData != nullptr)
            {
                WinTrustData.dwStateAction = WTD_STATEACTION_CLOSE;
                WinVerifyTrust(nullptr, &WvtPolicyGUID, &WinTrustData);
            }

            if (Status != ERROR_SUCCESS)
            {
                // %X expects an unsigned value; the status is reinterpreted bit-for-bit.
                return FString::Printf(TEXT("WinVerifyTrust failed with error: 0x%08X"), static_cast<uint32>(Status));
            }

            return TOptional<FString>();
        }

        /**
         * Reads the display name of the certificate belonging to signer index 0 of the
         * PKCS#7 signature embedded in the file at FullPath.
         *
         * @param FullPath      Absolute path of the signed file.
         * @param OutSignerName Receives the signer's simple display name on success; untouched on failure.
         * @return Unset on success; otherwise an error message naming the step that failed.
         */
        TOptional<FString> ReadPrimarySignerName(const FString& FullPath, FString& OutSignerName)
        {
            HCERTSTORE hStore = nullptr;
            HCRYPTMSG hMsg = nullptr;

            DWORD Encoding = 0, ContentType = 0, FormatType = 0;
            if (!CryptQueryObject(
                CERT_QUERY_OBJECT_FILE,
                *FullPath,
                CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED,
                CERT_QUERY_FORMAT_FLAG_BINARY,
                0,
                &Encoding,
                &ContentType,
                &FormatType,
                &hStore,
                &hMsg,
                nullptr))
            {
                return FString(TEXT("Failed to query object from DLL for signer info."));
            }

            // From here on every return path releases the store and the message.
            FScopedCertStore ScopedStore(hStore);
            FScopedCryptMsg ScopedMsg(hMsg);

            // First call only asks for the required buffer size.
            DWORD cbSignerInfo = 0;
            if (!CryptMsgGetParam(
                hMsg,
                CMSG_SIGNER_INFO_PARAM,
                0,
                nullptr,
                &cbSignerInfo))
            {
                return FString(TEXT("Failed to get signer info param size."));
            }

            // The buffer is read through a CMSG_SIGNER_INFO pointer below, so a reported size
            // smaller than the struct itself would make those reads run past the allocation.
            if (cbSignerInfo < sizeof(CMSG_SIGNER_INFO))
            {
                return FString(TEXT("Failed to get signer info param size."));
            }

            TArray<BYTE> SignerInfoData;
            SignerInfoData.SetNumUninitialized(cbSignerInfo);

            if (!CryptMsgGetParam(
                hMsg,
                CMSG_SIGNER_INFO_PARAM,
                0,
                SignerInfoData.GetData(),
                &cbSignerInfo))
            {
                return FString(TEXT("Failed to get signer info data."));
            }

            const PCMSG_SIGNER_INFO pSignerInfo = reinterpret_cast<PCMSG_SIGNER_INFO>(SignerInfoData.GetData());

            // Look the signer's certificate up in the store by issuer + serial number.
            CERT_INFO CertInfo = {};
            CertInfo.Issuer = pSignerInfo->Issuer;
            CertInfo.SerialNumber = pSignerInfo->SerialNumber;

            const PCCERT_CONTEXT pCertContext = CertFindCertificateInStore(
                hStore,
                X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
                0,
                CERT_FIND_SUBJECT_CERT,
                &CertInfo,
                nullptr);

            if (!pCertContext)
            {
                return FString(TEXT("Failed to find certificate context."));
            }

            FScopedCertContext ScopedContext(pCertContext);

            TCHAR NameBuffer[512];
            const DWORD NameLength = CertGetNameString(
                pCertContext,
                CERT_NAME_SIMPLE_DISPLAY_TYPE,
                0,
                nullptr,
                NameBuffer,
                UE_ARRAY_COUNT(NameBuffer));

            // A length of 1 or less is treated as "no usable name".
            if (NameLength <= 1)
            {
                return FString(TEXT("Failed to get signer name."));
            }

            OutSignerName = FString(NameBuffer);
            return TOptional<FString>();
        }
    }
}

/**
 * Verifies that a DLL carries a trusted Authenticode signature whose signer matches an expected name.
 *
 * The check has two stages: WinVerifyTrust must accept the file, and the simple display name of
 * the first signer's certificate must equal ExpectedSignerName (case-insensitive).
 *
 * NOTE(review): the file is opened by path once per stage, and only the display name is compared
 * (no thumbprint or public-key pinning) - confirm this is sufficient for the caller's threat model.
 *
 * @param DllPath            Path of the DLL; relative paths are converted to absolute ones.
 * @param ExpectedSignerName Display name the signer certificate must have.
 * @return Unset when verification succeeds; otherwise a message describing the failure.
 */
TOptional<FString> GamebaseVerifySignatureHelper::VerifyDll(const FString& DllPath, const FString& ExpectedSignerName)
{
    if (DllPath.IsEmpty())
    {
        return FString(TEXT("DLL path is empty."));
    }

    const FString FullPath = FPaths::ConvertRelativePathToFull(DllPath);

    TOptional<FString> Error = CheckAuthenticodeTrust(FullPath);
    if (Error.IsSet())
    {
        return Error;
    }

    FString SignerName;
    Error = ReadPrimarySignerName(FullPath, SignerName);
    if (Error.IsSet())
    {
        return Error;
    }

    if (!SignerName.Equals(ExpectedSignerName, ESearchCase::IgnoreCase))
    {
        return FString::Printf(TEXT("Signer name mismatch. Expected: [%s], Actual: [%s]"), *ExpectedSignerName, *SignerName);
    }

    return TOptional<FString>();
}

