package sm2

import (
	"crypto/elliptic"
	"encoding/asn1"
	"encoding/pem"
	"errors"
	"fmt"
	"github.com/hyperledger/fabric-sdk-go/internal/github.com/hyperledger/fabric/bctls/x509/pkix"
	"math/big"
	"reflect"
)

// publicKeyInfo mirrors the PKIX SubjectPublicKeyInfo SEQUENCE: an
// AlgorithmIdentifier followed by the public key as a BIT STRING.
// Because Raw is an asn1.RawContent first field, asn1.Unmarshal stores the
// complete DER encoding of the SEQUENCE in it.
type publicKeyInfo struct {
	Raw       asn1.RawContent
	Algorithm pkix.AlgorithmIdentifier
	PublicKey asn1.BitString
}

// pkixPublicKey is the SubjectPublicKeyInfo SEQUENCE that MarshalPublicKey
// encodes and the public-key parsing code decodes. Algo identifies the key
// algorithm. BitString holds the curve point as produced by elliptic.Marshal.
type pkixPublicKey struct {
	Algo      pkix.AlgorithmIdentifier
	BitString asn1.BitString
}

// pkcs8 is a minimal PKCS#8 PrivateKeyInfo (RFC 5208). It holds a version, the
// private-key algorithm identifier, and an OCTET STRING carrying the
// algorithm-specific key. In this file, that payload is a DER-encoded SEC 1
// ECPrivateKey. The optional attributes field is not modeled.
type pkcs8 struct {
	Version    int
	Algo       pkix.AlgorithmIdentifier
	PrivateKey []byte
}

// sm2PrivateKeyInfo is the SEC 1 ECPrivateKey SEQUENCE that MarshalPrivateKey
// DER-encodes and places inside the PKCS#8 wrapper. Fields, in order:
//   - the version,
//   - the private scalar bytes,
//   - the named-curve OID (optional, EXPLICIT [0]),
//   - the public point (optional, EXPLICIT [1]).
type sm2PrivateKeyInfo struct {
	Version       int
	PrivateKey    []byte
	NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"`
	PublicKey     asn1.BitString        `asn1:"optional,explicit,tag:1"`
}

// MarshalPublicKey encodes key as a DER PKIX SubjectPublicKeyInfo.
//
// The algorithm identifier is OidSM2, and its parameters are the SM2 curve OID
// 1.2.156.10197.1.301. The subject public key is the point encoding returned
// by elliptic.Marshal.
//
// An error is returned if key, or either of its coordinates, is nil.
func MarshalPublicKey(key *PublicKey) ([]byte, error) {
	if key == nil {
		return nil, errors.New("input SM2 public key is null")
	}
	if key.X == nil || key.Y == nil {
		// elliptic.Marshal would panic on nil coordinates; report it instead.
		return nil, errors.New("input SM2 public key is incomplete")
	}

	var algo pkix.AlgorithmIdentifier
	algo.Algorithm = OidSM2
	// Pre-encoded parameters: universal OBJECT IDENTIFIER (tag 6) holding
	// asn1.Marshal(asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301}).
	algo.Parameters.Class = 0
	algo.Parameters.Tag = 6
	algo.Parameters.IsCompound = false
	algo.Parameters.FullBytes = []byte{6, 8, 42, 129, 28, 207, 85, 1, 130, 45}

	point := elliptic.Marshal(key.Curve, key.X, key.Y)
	return asn1.Marshal(pkixPublicKey{
		Algo: algo,
		// The point is a whole number of bytes, so the bit length is exact.
		BitString: asn1.BitString{Bytes: point, BitLength: 8 * len(point)},
	})
}

// ParsePublicKey parses a DER-encoded PKIX SubjectPublicKeyInfo holding an SM2
// public key, such as the output of MarshalPublicKey.
//
// The algorithm identifier must equal OidSM2. The point is decoded on the
// curve returned by P256Sm2.
//
// An error is returned in any of these cases:
//   - the DER is malformed,
//   - the DER is followed by trailing bytes,
//   - the algorithm is not SM2,
//   - the point is not a valid point on the curve.
func ParsePublicKey(der []byte) (*PublicKey, error) {
	var pubkey pkixPublicKey
	rest, err := asn1.Unmarshal(der, &pubkey)
	if err != nil {
		return nil, err
	}
	if len(rest) != 0 {
		return nil, errors.New("x509: trailing data after sm2 public key")
	}
	if !reflect.DeepEqual(pubkey.Algo.Algorithm, OidSM2) {
		return nil, errors.New("x509: not sm2 elliptic curve")
	}

	curve := P256Sm2()
	// elliptic.Unmarshal signals a malformed or off-curve point by returning
	// (nil, nil). Reject it rather than building a key with nil coordinates.
	x, y := elliptic.Unmarshal(curve, pubkey.BitString.Bytes)
	if x == nil {
		return nil, errors.New("x509: invalid sm2 public key point")
	}
	return &PublicKey{
		Curve: curve,
		X:     x,
		Y:     y,
	}, nil
}

// PublicKeyFromPEM decodes the first PEM block in pkPEM and parses its
// contents with ParsePublicKey. The PEM block type is not checked.
//
// An error is returned if pkPEM contains no PEM block, or if parsing fails.
func PublicKeyFromPEM(pkPEM string) (*PublicKey, error) {
	block, _ := pem.Decode([]byte(pkPEM))
	if block == nil {
		return nil, errors.New("x509: failed to decode sm2 public key PEM")
	}
	return ParsePublicKey(block.Bytes)
}

//func PublicKeyToPEM(key *PublicKey) (string, error) {
//	return key.k.GetPEM()
//}

// ecPrivateKey is the SEC 1 ECPrivateKey structure that
// UnmarshalPrivateKeyWithCurve decodes, either bare or from inside a PKCS#8
// wrapper. Fields, in order:
//   - the version,
//   - the private scalar bytes,
//   - the named-curve OID (optional, EXPLICIT [0]),
//   - the public point (optional, EXPLICIT [1]).
//
// Its layout is identical to sm2PrivateKeyInfo.
type ecPrivateKey struct {
	Version       int
	PrivateKey    []byte
	NamedCurveOID asn1.ObjectIdentifier `asn1:"optional,explicit,tag:0"`
	PublicKey     asn1.BitString        `asn1:"optional,explicit,tag:1"`
}

// ecPrivKeyVersion is the only ECPrivateKey version that
// UnmarshalPrivateKeyWithCurve accepts.
const ecPrivKeyVersion = 1

// UnmarshalPrivateKey parses a DER-encoded SM2 private key. It is shorthand
// for UnmarshalPrivateKeyWithCurve without a curve OID hint.
func UnmarshalPrivateKey(der []byte) (*PrivateKey, error) {
	var noCurveHint *asn1.ObjectIdentifier
	parsed, err := UnmarshalPrivateKeyWithCurve(noCurveHint, der)
	return parsed, err
}

// UnmarshalPrivateKeyWithCurve parses a DER-encoded SM2 private key.
//
// der may take either of two forms:
//   - a PKCS#8 PrivateKeyInfo whose algorithm is OidSM2 and which wraps a
//     SEC 1 ECPrivateKey;
//   - a bare SEC 1 ECPrivateKey.
//
// The key is always interpreted on the curve returned by P256Sm2. The public
// key is recomputed from the private scalar.
//
// Some inputs are not consulted and are kept only for compatibility:
//   - namedCurveOID is ignored;
//   - the curve OID embedded in the SEC 1 structure is not read;
//   - any public key stored in the SEC 1 structure is not read.
//
// An error is returned in any of these cases:
//   - a PKCS#8 wrapper names a non-SM2 algorithm,
//   - the SEC 1 structure is malformed or its version is not ecPrivKeyVersion,
//   - the scalar is zero or not less than the curve order.
func UnmarshalPrivateKeyWithCurve(namedCurveOID *asn1.ObjectIdentifier, der []byte) (key *PrivateKey, err error) {
	// Try the PKCS#8 wrapper first. If der is not PKCS#8, treat it as a bare
	// SEC 1 ECPrivateKey.
	privateKeyBytes := der
	var privKeyPKCS8 pkcs8
	if _, err = asn1.Unmarshal(der, &privKeyPKCS8); err == nil {
		if !reflect.DeepEqual(privKeyPKCS8.Algo.Algorithm, OidSM2) {
			return nil, errors.New("x509: not sm2 elliptic curve")
		}
		privateKeyBytes = privKeyPKCS8.PrivateKey
	}

	var privKey ecPrivateKey
	if _, err = asn1.Unmarshal(privateKeyBytes, &privKey); err != nil {
		return nil, fmt.Errorf("x509: failed to parse sm2 private key: %w", err)
	}
	if privKey.Version != ecPrivKeyVersion {
		return nil, fmt.Errorf("x509: unknown sm2 private key version %d", privKey.Version)
	}
	curve := P256Sm2()

	k := new(big.Int).SetBytes(privKey.PrivateKey)
	curveOrder := curve.Params().N
	// A usable scalar lies in [1, N-1]. Values >= N are out of range. Zero
	// would map to the point at infinity instead of a real public key.
	if k.Sign() == 0 || k.Cmp(curveOrder) >= 0 {
		return nil, errors.New("x509: invalid elliptic curve private key value")
	}

	scalar := make([]byte, (curveOrder.BitLen()+7)/8)
	raw := privKey.PrivateKey

	// Some private keys have leading zero padding. This is invalid
	// according to [SEC1], but this code will ignore it.
	for len(raw) > len(scalar) {
		if raw[0] != 0 {
			return nil, errors.New("x509: invalid private key length")
		}
		raw = raw[1:]
	}

	// Some private keys remove all leading zeros, this is also invalid
	// according to [SEC1] but since OpenSSL used to do this, we ignore
	// this too. Right-align the bytes into the fixed-size buffer.
	copy(scalar[len(scalar)-len(raw):], raw)
	x, y := curve.ScalarBaseMult(scalar)

	return &PrivateKey{
		PublicKey: PublicKey{
			Curve: curve,
			X:     x,
			Y:     y,
		},
		D: k,
	}, nil
}

// MarshalPrivateKey encodes key as a DER PKCS#8 PrivateKeyInfo.
//
// The wrapper's algorithm is OidSM2, with the SM2 curve OID
// 1.2.156.10197.1.301 as parameters. Its payload is a SEC 1 ECPrivateKey
// (version 1) containing:
//   - the scalar, left-padded to the byte length of the curve order,
//   - OidNamedCurveSm2,
//   - the public point encoded by elliptic.Marshal.
//
// An error is returned in any of these cases:
//   - key is nil,
//   - D, X, or Y is nil,
//   - D does not fit in the curve-order byte length,
//   - ASN.1 encoding fails.
func MarshalPrivateKey(key *PrivateKey) ([]byte, error) {
	if key == nil {
		return nil, errors.New("input SM2 private key is null")
	}
	if key.D == nil || key.X == nil || key.Y == nil {
		return nil, errors.New("input SM2 private key is incomplete")
	}

	// SEC 1 encodes the scalar as a fixed-length octet string. big.Int.Bytes
	// drops leading zeros, so right-align it into a buffer of that length.
	size := (key.Curve.Params().N.BitLen() + 7) / 8
	d := key.D.Bytes()
	if len(d) > size {
		return nil, errors.New("x509: invalid sm2 private key value")
	}
	scalar := make([]byte, size)
	copy(scalar[size-len(d):], d)

	var algo pkix.AlgorithmIdentifier
	algo.Algorithm = OidSM2
	// Pre-encoded parameters: universal OBJECT IDENTIFIER (tag 6) holding
	// asn1.Marshal(asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301}).
	algo.Parameters.Class = 0
	algo.Parameters.Tag = 6
	algo.Parameters.IsCompound = false
	algo.Parameters.FullBytes = []byte{6, 8, 42, 129, 28, 207, 85, 1, 130, 45}

	point := elliptic.Marshal(key.Curve, key.X, key.Y)
	inner, err := asn1.Marshal(sm2PrivateKeyInfo{
		Version:       1,
		PrivateKey:    scalar,
		NamedCurveOID: OidNamedCurveSm2,
		PublicKey:     asn1.BitString{Bytes: point, BitLength: 8 * len(point)},
	})
	if err != nil {
		return nil, fmt.Errorf("x509: failed to marshal sm2 private key: %w", err)
	}
	return asn1.Marshal(pkcs8{
		Version:    0,
		Algo:       algo,
		PrivateKey: inner,
	})
}

// PrivateKeyFromPEM decodes the first PEM block in skPEM and parses its bytes
// with UnmarshalPrivateKey. The PEM block type is not checked.
//
// An error is returned when skPEM contains no decodable PEM block.
func PrivateKeyFromPEM(skPEM string) (*PrivateKey, error) {
	if block, _ := pem.Decode([]byte(skPEM)); block != nil {
		return UnmarshalPrivateKey(block.Bytes)
	}
	return nil, errors.New("X509: fail to unmarshal private key, PEM is invalid")
}

//func PrivateKeyToPEM(key *PrivateKey) (string, error) {
//	return key.k.GetUnencryptedPEM()
//}
