/*
 *  Unix SMB/CIFS implementation.
 *  libnet Join offline support
 *  Copyright (C) Guenther Deschner 2021
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, see <http://www.gnu.org/licenses/>.
 */

#include "includes.h"
#include "librpc/gen_ndr/ndr_libnet_join.h"
#include "../librpc/gen_ndr/ndr_ODJ.h"
#include "libnet/libnet_join_offline.h"
#include "libcli/security/dom_sid.h"
#include "rpc_client/util_netlogon.h"

/*
 * Fill in an ODJ_WIN7BLOB from the results of a completed join.
 *
 * mem_ctx: talloc parent for the strings and SID copied into *b.
 * r:       join context; r->out.* and r->in.machine_password are read.
 * b:       output structure; it is zeroed first and then populated.
 *
 * Returns WERR_OK on success or WERR_NOT_ENOUGH_MEMORY when a copy
 * fails. talloc_strdup() also returns NULL for a NULL source, so a
 * missing input string is reported as WERR_NOT_ENOUGH_MEMORY as well.
 * On failure *b may be partially populated.
 */
static WERROR libnet_odj_compose_ODJ_WIN7BLOB(TALLOC_CTX *mem_ctx,
					      const struct libnet_JoinCtx *r,
					      struct ODJ_WIN7BLOB *b)
{
	char *samaccount;
	size_t len;
	struct ODJ_POLICY_DNS_DOMAIN_INFO i = {
		.Sid = NULL,
	};

	ZERO_STRUCTP(b);

	b->lpDomain = talloc_strdup(mem_ctx, r->out.dns_domain_name);
	if (b->lpDomain == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	samaccount = talloc_strdup(mem_ctx, r->out.account_name);
	if (samaccount == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}
	/*
	 * lpMachineName is the account name with a trailing '$' removed.
	 * Guard len == 0: indexing samaccount[len - 1] on an empty string
	 * would read far outside the buffer.
	 */
	len = strlen(samaccount);
	if (len > 0 && samaccount[len - 1] == '$') {
		samaccount[len - 1] = '\0';
	}
	b->lpMachineName = samaccount;

	b->lpMachinePassword = talloc_strdup(mem_ctx, r->in.machine_password);
	if (b->lpMachinePassword == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	/* fill up ODJ_POLICY_DNS_DOMAIN_INFO */

	i.Name.string = talloc_strdup(mem_ctx, r->out.netbios_domain_name);
	if (i.Name.string == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	i.DnsDomainName.string = talloc_strdup(mem_ctx, r->out.dns_domain_name);
	if (i.DnsDomainName.string == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	i.DnsForestName.string = talloc_strdup(mem_ctx, r->out.forest_name);
	if (i.DnsForestName.string == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	i.DomainGuid = r->out.domain_guid;
	i.Sid = dom_sid_dup(mem_ctx, r->out.domain_sid);
	if (i.Sid == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	b->DnsDomainInfo = i;

	if (r->out.dcinfo != NULL) {
		struct netr_DsRGetDCNameInfo *p;

		/*
		 * NOTE(review): talloc_steal() re-parents r->out.dcinfo under
		 * mem_ctx even though r is passed as const. From here on
		 * mem_ctx owns it, so freeing mem_ctx leaves r->out.dcinfo
		 * dangling. b->DcInfo is a shallow struct copy; any pointer
		 * members still point into the stolen allocation.
		 * talloc_steal() only returns NULL for a NULL pointer, so the
		 * check below is purely defensive.
		 */
		p = talloc_steal(mem_ctx, r->out.dcinfo);
		if (p == NULL) {
			return WERR_NOT_ENOUGH_MEMORY;
		}

		b->DcInfo = *p;
	}

	/*
	 * According to
	 * https://docs.microsoft.com/en-us/windows/win32/netmgmt/odj-odj_win7blob
	 * it should be 0 but Windows 2019 always sets 6 - gd.
	 */
	b->Options = 6;

	return WERR_OK;
}

/*
 * Placeholder for the JOINPROV2 package part.
 *
 * An empty, zeroed OP_JOINPROV2_PART is allocated on mem_ctx and handed
 * back through *p. The part is not filled in yet, so the function
 * always reports WERR_INVALID_LEVEL once the allocation succeeded, and
 * WERR_NOT_ENOUGH_MEMORY (leaving *p untouched) when it did not.
 * r is currently unused.
 */
static WERROR libnet_odj_compose_OP_JOINPROV2_PART(TALLOC_CTX *mem_ctx,
						   const struct libnet_JoinCtx *r,
						   struct OP_JOINPROV2_PART **p)
{
	struct OP_JOINPROV2_PART *part = talloc_zero(mem_ctx,
						     struct OP_JOINPROV2_PART);

	if (part != NULL) {
		/* TODO: populate the JOINPROV2 part from r */
		*p = part;
		return WERR_INVALID_LEVEL;
	}

	return WERR_NOT_ENOUGH_MEMORY;
}

/*
 * Build the JOINPROV3 package part for the joined machine account.
 *
 * Rid is taken from r->out.account_rid. lpSid is the string form of
 * r->out.domain_sid with that RID appended. The part is allocated on
 * mem_ctx and stored in *p only on success. On failure nothing
 * allocated here is left behind on mem_ctx.
 *
 * Returns WERR_INVALID_PARAMETER if r->out.domain_sid is NULL, and
 * WERR_NOT_ENOUGH_MEMORY if building the SID or its string form fails.
 */
static WERROR libnet_odj_compose_OP_JOINPROV3_PART(TALLOC_CTX *mem_ctx,
						   const struct libnet_JoinCtx *r,
						   struct OP_JOINPROV3_PART **p)
{
	struct OP_JOINPROV3_PART *b;
	struct dom_sid *sid;

	/* dom_sid_add_rid() copies *domain_sid, so it must not be NULL. */
	if (r->out.domain_sid == NULL) {
		return WERR_INVALID_PARAMETER;
	}

	b = talloc_zero(mem_ctx, struct OP_JOINPROV3_PART);
	if (b == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	b->Rid = r->out.account_rid;

	/*
	 * The binary SID is only needed to render lpSid. Parent it on b
	 * and free it right after, rather than leaving it on mem_ctx.
	 */
	sid = dom_sid_add_rid(b, r->out.domain_sid, r->out.account_rid);
	if (sid == NULL) {
		talloc_free(b);
		return WERR_NOT_ENOUGH_MEMORY;
	}

	b->lpSid = dom_sid_string(b, sid);
	talloc_free(sid);
	if (b->lpSid == NULL) {
		talloc_free(b);
		return WERR_NOT_ENOUGH_MEMORY;
	}

	*p = b;

	return WERR_OK;
}

/*
 * Populate one OP_PACKAGE_PART for the join provider whose GUID string
 * is join_provider_guid. flags is stored verbatim in ulFlags.
 *
 * The GUID is mapped to a union level with odj_switch_level_from_guid():
 *   1 (ODJ_GUID_JOIN_PROVIDER)  - copies *win7, which must be non-NULL
 *   2 (ODJ_GUID_JOIN_PROVIDER2) - libnet_odj_compose_OP_JOINPROV2_PART()
 *   3 (ODJ_GUID_JOIN_PROVIDER3) - libnet_odj_compose_OP_JOINPROV3_PART()
 * Any other level yields WERR_INVALID_LEVEL.
 *
 * p->Part is allocated on mem_ctx. Returns WERR_INVALID_PARAMETER if
 * the GUID string does not parse, or if level 1 is requested without
 * win7. Returns WERR_NOT_ENOUGH_MEMORY on allocation failure. Otherwise
 * the result of the level-specific helper is returned.
 */
static WERROR libnet_odj_compose_OP_PACKAGE_PART(TALLOC_CTX *mem_ctx,
						 const struct libnet_JoinCtx *r,
						 const struct ODJ_WIN7BLOB *win7,
						 const char *join_provider_guid,
						 uint32_t flags,
						 struct OP_PACKAGE_PART *p)
{
	struct GUID guid;
	uint32_t level;
	WERROR werr;

	if (!NT_STATUS_IS_OK(GUID_from_string(join_provider_guid, &guid))) {
		/* A malformed GUID string is a caller error, not an OOM. */
		return WERR_INVALID_PARAMETER;
	}

	level = odj_switch_level_from_guid(&guid);

	p->PartType	= guid;
	p->ulFlags	= flags;
	p->part_len	= 0; /* autogenerated */
	p->Part = talloc_zero(mem_ctx, union OP_PACKAGE_PART_u);
	if (p->Part == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	switch (level) {
	case 1: /* ODJ_GUID_JOIN_PROVIDER */
		if (win7 == NULL) {
			return WERR_INVALID_PARAMETER;
		}
		p->Part->win7blob = *win7;
		break;
	case 2: /* ODJ_GUID_JOIN_PROVIDER2 */
		werr = libnet_odj_compose_OP_JOINPROV2_PART(mem_ctx, r,
				&p->Part->join_prov2.p);
		if (!W_ERROR_IS_OK(werr)) {
			return werr;
		}
		break;
	case 3: /* ODJ_GUID_JOIN_PROVIDER3 */
		werr = libnet_odj_compose_OP_JOINPROV3_PART(mem_ctx, r,
				&p->Part->join_prov3.p);
		if (!W_ERROR_IS_OK(werr)) {
			return werr;
		}
		break;
	default:
		return WERR_INVALID_LEVEL;
	}

	return WERR_OK;
}

/*
 * Build the two-entry part collection embedded in an OP_PACKAGE.
 *
 *   pParts[0] - ODJ_GUID_JOIN_PROVIDER, flagged
 *               OPSPI_PACKAGE_PART_ESSENTIAL, carrying a copy of *win7
 *   pParts[1] - ODJ_GUID_JOIN_PROVIDER3, no flags
 *
 * The collection is allocated on mem_ctx and returned via *pp only on
 * success. On failure it is freed and the failing step's error is
 * returned.
 */
static WERROR libnet_odj_compose_OP_PACKAGE_PART_COLLECTION(TALLOC_CTX *mem_ctx,
							    const struct libnet_JoinCtx *r,
							    const struct ODJ_WIN7BLOB *win7,
							    struct OP_PACKAGE_PART_COLLECTION **pp)
{
	struct OP_PACKAGE_PART_COLLECTION *coll;
	WERROR status;

	coll = talloc_zero(mem_ctx, struct OP_PACKAGE_PART_COLLECTION);
	if (coll == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	coll->cParts = 2;
	coll->pParts = talloc_zero_array(coll, struct OP_PACKAGE_PART,
					 coll->cParts);
	if (coll->pParts == NULL) {
		status = WERR_NOT_ENOUGH_MEMORY;
		goto fail;
	}

	status = libnet_odj_compose_OP_PACKAGE_PART(coll, r, win7,
						    ODJ_GUID_JOIN_PROVIDER,
						    OPSPI_PACKAGE_PART_ESSENTIAL,
						    &coll->pParts[0]);
	if (!W_ERROR_IS_OK(status)) {
		goto fail;
	}

	status = libnet_odj_compose_OP_PACKAGE_PART(coll, r, NULL,
						    ODJ_GUID_JOIN_PROVIDER3,
						    0,
						    &coll->pParts[1]);
	if (!W_ERROR_IS_OK(status)) {
		goto fail;
	}

	*pp = coll;
	return WERR_OK;

fail:
	talloc_free(coll);
	return status;
}

/*
 * Build an OP_PACKAGE wrapping the part collection produced by
 * libnet_odj_compose_OP_PACKAGE_PART_COLLECTION().
 *
 * EncryptionType is set to the all-zero GUID, and the wrapped blob
 * length is left at 0 (autogenerated). The package is allocated on
 * mem_ctx and returned via *pp only on success. On failure it is freed
 * and the error is returned.
 */
static WERROR libnet_odj_compose_OP_PACKAGE(TALLOC_CTX *mem_ctx,
					    const struct libnet_JoinCtx *r,
					    const struct ODJ_WIN7BLOB *win7,
					    struct OP_PACKAGE **pp)
{
	struct OP_PACKAGE *pkg;
	struct OP_PACKAGE_PART_COLLECTION *parts = NULL;
	struct OP_PACKAGE_PART_COLLECTION_serialized_ptr *wrapper;
	WERROR status;

	pkg = talloc_zero(mem_ctx, struct OP_PACKAGE);
	if (pkg == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	status = libnet_odj_compose_OP_PACKAGE_PART_COLLECTION(pkg, r, win7,
							       &parts);
	if (!W_ERROR_IS_OK(status)) {
		goto fail;
	}

	wrapper = talloc_zero(pkg,
			struct OP_PACKAGE_PART_COLLECTION_serialized_ptr);
	if (wrapper == NULL) {
		status = WERR_NOT_ENOUGH_MEMORY;
		goto fail;
	}
	wrapper->s.p = parts;

	pkg->EncryptionType = GUID_zero();
	pkg->WrappedPartCollection.cbBlob = 0; /* autogenerated */
	pkg->WrappedPartCollection.w = wrapper;

	*pp = pkg;
	return WERR_OK;

fail:
	talloc_free(pkg);
	return status;
}

/*
 * Build a complete ODJ_PROVISION_DATA (ulVersion 1) for the join in r.
 *
 * The result carries two blobs:
 *   pBlobs[0] - ODJ_WIN7_FORMAT holding the composed ODJ_WIN7BLOB
 *   pBlobs[1] - ODJ_WIN8_FORMAT holding an OP_PACKAGE whose part
 *               collection embeds the same ODJ_WIN7BLOB
 * Both blob lengths are left at 0 (autogenerated).
 *
 * Everything is allocated under one talloc parent on mem_ctx, which is
 * returned via *b_p only on success. On failure that parent is freed
 * and the failing step's error is returned.
 */
WERROR libnet_odj_compose_ODJ_PROVISION_DATA(TALLOC_CTX *mem_ctx,
					     const struct libnet_JoinCtx *r,
					     struct ODJ_PROVISION_DATA **b_p)
{
	struct ODJ_PROVISION_DATA *prov;
	struct ODJ_WIN7BLOB win7;
	struct OP_PACKAGE *package;
	union ODJ_BLOB_u *win7_u;
	union ODJ_BLOB_u *win8_u;
	WERROR status;

	prov = talloc_zero(mem_ctx, struct ODJ_PROVISION_DATA);
	if (prov == NULL) {
		return WERR_NOT_ENOUGH_MEMORY;
	}

	prov->ulVersion = 1;
	prov->ulcBlobs = 2;
	prov->pBlobs = talloc_zero_array(prov, struct ODJ_BLOB,
					 prov->ulcBlobs);
	if (prov->pBlobs == NULL) {
		status = WERR_NOT_ENOUGH_MEMORY;
		goto fail;
	}

	status = libnet_odj_compose_ODJ_WIN7BLOB(prov, r, &win7);
	if (!W_ERROR_IS_OK(status)) {
		goto fail;
	}

	status = libnet_odj_compose_OP_PACKAGE(prov, r, &win7, &package);
	if (!W_ERROR_IS_OK(status)) {
		goto fail;
	}

	win7_u = talloc_zero(prov, union ODJ_BLOB_u);
	if (win7_u == NULL) {
		status = WERR_NOT_ENOUGH_MEMORY;
		goto fail;
	}
	win7_u->odj_win7blob = win7;

	win8_u = talloc_zero(prov, union ODJ_BLOB_u);
	if (win8_u == NULL) {
		status = WERR_NOT_ENOUGH_MEMORY;
		goto fail;
	}
	win8_u->op_package.p = package;

	prov->pBlobs[0] = (struct ODJ_BLOB) {
		.ulODJFormat = ODJ_WIN7_FORMAT,
		.cbBlob = 0, /* autogenerated */
		.pBlob = win7_u,
	};
	prov->pBlobs[1] = (struct ODJ_BLOB) {
		.ulODJFormat = ODJ_WIN8_FORMAT,
		.cbBlob = 0, /* autogenerated */
		.pBlob = win8_u,
	};

	*b_p = prov;
	return WERR_OK;

fail:
	talloc_free(prov);
	return status;
}

WERROR libnet_odj_find_win7blob(const struct ODJ_PROVISION_DATA *r,
				struct ODJ_WIN7BLOB *win7blob)
{
	int i;

	if (r == NULL) {
		return WERR_INVALID_PARAMETER;
	}

	for (i = 0; i < r->ulcBlobs; i++) {

		struct ODJ_BLOB b = r->pBlobs[i];

		switch (b.ulODJFormat) {
		case ODJ_WIN7_FORMAT:
			*win7blob = b.pBlob->odj_win7blob;
			return WERR_OK;

		case ODJ_WIN8_FORMAT: {
			NTSTATUS status;
			struct OP_PACKAGE_PART_COLLECTION *col;
			struct GUID guid;
			int k;

			if (b.pBlob->op_package.p->WrappedPartCollection.w == NULL) {
				return WERR_BAD_FORMAT;
			}

			col = b.pBlob->op_package.p->WrappedPartCollection.w->s.p;

			status = GUID_from_string(ODJ_GUID_JOIN_PROVIDER, &guid);
			if (!NT_STATUS_IS_OK(status)) {
				return WERR_NOT_ENOUGH_MEMORY;
			}

			for (k = 0; k < col->cParts; k++) {
				if (GUID_equal(&guid, &col->pParts[k].PartType)) {
					*win7blob = col->pParts[k].Part->win7blob;
					return WERR_OK;
				}
			}
			break;
		}
		default:
			return WERR_BAD_FORMAT;
		}
	}

	return WERR_BAD_FORMAT;
}


WERROR libnet_odj_find_joinprov3(const struct ODJ_PROVISION_DATA *r,
				 struct OP_JOINPROV3_PART *joinprov3)
{
	int i;

	if (r == NULL) {
		return WERR_INVALID_PARAMETER;
	}

	for (i = 0; i < r->ulcBlobs; i++) {

		struct ODJ_BLOB b = r->pBlobs[i];

		switch (b.ulODJFormat) {
		case ODJ_WIN7_FORMAT:
			continue;

		case ODJ_WIN8_FORMAT: {
			NTSTATUS status;
			struct OP_PACKAGE_PART_COLLECTION *col;
			struct GUID guid;
			int k;

			if (b.pBlob->op_package.p->WrappedPartCollection.w == NULL) {
				return WERR_BAD_FORMAT;
			}

			col = b.pBlob->op_package.p->WrappedPartCollection.w->s.p;

			status = GUID_from_string(ODJ_GUID_JOIN_PROVIDER3, &guid);
			if (!NT_STATUS_IS_OK(status)) {
				return WERR_NOT_ENOUGH_MEMORY;
			}

			for (k = 0; k < col->cParts; k++) {
				if (GUID_equal(&guid, &col->pParts[k].PartType)) {
					*joinprov3 = *col->pParts[k].Part->join_prov3.p;
					return WERR_OK;
				}
			}
			break;
		}
		default:
			return WERR_BAD_FORMAT;
		}
	}

	return WERR_BAD_FORMAT;
}
