package sm2

import (
	"bytes"
	gocrypto "crypto"
	//"crypto/cipher"
	"crypto/elliptic"
	"crypto/rand"
	"encoding/asn1"
	"errors"
	crypto "github.com/hyperledger/fabric-sdk-go/internal/github.com/hyperledger/fabric/bctls"
	tjfoc "github.com/tjfoc/gmsm/sm2"
	"github.com/tjfoc/gmsm/sm3"
	"io"
	"math/big"
)

// DEFAULT_USER_ID is the signer identity hashed into Z_A whenever a caller
// supplies an empty uid.
const DEFAULT_USER_ID = "1234567812345678"

var (
	// errZeroParam is returned when the curve order N is zero.
	errZeroParam = errors.New("zero parameter")

	// one and two are shared read-only constants; never use them as a receiver
	// of a mutating big.Int method.
	one = big.NewInt(1)
	two = big.NewInt(2)
)

// PrivateKey is an SM2 private key: the embedded public key plus the private
// scalar D. GenerateKey sets X, Y to D*G on the P256Sm2 curve.
type PrivateKey struct {
	PublicKey
	// D is the private scalar.
	D *big.Int
	//k *gmssl.PrivateKey
}

// PublicKey is an SM2 public key: the curve it belongs to (embedded, so the
// elliptic.Curve methods are promoted onto PublicKey) and the point (X, Y).
type PublicKey struct {
	elliptic.Curve
	X, Y *big.Int
	//k    *gmssl.PublicKey
}


// sm2Signature is the value Sign passes to asn1.Marshal, producing
// SEQUENCE { R INTEGER, S INTEGER }. Field order defines the encoding order.
type sm2Signature struct {
	R, S *big.Int
}

// ParsePktoTJ converts pk into the equivalent tjfoc/gmsm public key. The
// curve and coordinate values are shared with pk, not copied.
func ParsePktoTJ(pk *PublicKey) *tjfoc.PublicKey {
	converted := new(tjfoc.PublicKey)
	converted.Curve = pk.Curve
	converted.X, converted.Y = pk.X, pk.Y
	return converted
}

// ParseTJtoPriv converts a tjfoc/gmsm SM2 private key into this package's
// PrivateKey. The curve and big.Int values are shared, not copied.
func ParseTJtoPriv(tjpriv *tjfoc.PrivateKey) *PrivateKey {
	priv := &PrivateKey{D: tjpriv.D}
	priv.Curve = tjpriv.Curve
	priv.X, priv.Y = tjpriv.X, tjpriv.Y
	return priv
}

// ParsePrivtoTJ converts priv into the equivalent tjfoc/gmsm private key.
// The public half is converted with ParsePktoTJ; all values are shared with
// priv, not copied.
func ParsePrivtoTJ(priv *PrivateKey) *tjfoc.PrivateKey {
	converted := &tjfoc.PrivateKey{D: priv.D}
	converted.PublicKey = *ParsePktoTJ(&priv.PublicKey)
	return converted
}

// createZA computes Z_A, the SM3 hash binding the signer identity and public
// key into an SM2 signature:
//
//	Z_A = SM3(ENTL_A || ID_A || a || b || x_G || y_G || x_A || y_A)
//
// ENTL_A is the bit length of uid as a two-byte big-endian value, so uid
// must be shorter than 8192 bytes. The curve parameters come from the
// package-level p256 curve. The public key coordinates are left-padded with
// zeros to 32 bytes.
func createZA(pk *PublicKey, uid string) ([]byte, error) {
	if len(uid) >= 8192 {
		return []byte{}, errors.New("SM2: uid too large")
	}
	entl := uint16(8 * len(uid))

	h := sm3.New()
	h.Write([]byte{byte(entl >> 8), byte(entl)})
	if uid != "" {
		h.Write([]byte(uid))
	}
	curveParams := [][]byte{
		P256ToBig(&p256.a).Bytes(),
		p256.B.Bytes(),
		p256.Gx.Bytes(),
		p256.Gy.Bytes(),
	}
	for _, param := range curveParams {
		h.Write(param)
	}
	for _, coord := range [][]byte{pk.X.Bytes(), pk.Y.Bytes()} {
		if short := 32 - len(coord); short > 0 {
			coord = append(make([]byte, short), coord...)
		}
		h.Write(coord)
	}
	return h.Sum(nil)[:32], nil
}

// zeroByteSlice returns a newly allocated 32-byte slice (len and cap 32)
// filled with zeros. Each call returns independent storage.
func zeroByteSlice() []byte {
	return make([]byte, 32)
}

// createDigest returns e = SM3(Z_A || msg), the value SM2 signs, where zid is
// the Z_A value produced by createZA.
//
// The full 32-byte SM3 output is returned, including any leading zero bytes.
// The hash must not be round-tripped through big.Int: Int.Bytes drops leading
// zeros, which would make about 1 in 256 digests shorter than 32 bytes.
// Sign and Verify interpret the digest as a big-endian integer, so their
// results do not depend on this.
func createDigest(zid, msg []byte) ([]byte, error) {
	h := sm3.New()
	h.Write(zid)
	h.Write(msg)
	return h.Sum(nil)[:32], nil
}

// ComputeSM3Digest returns the SM2 message digest e = SM3(Z_A || msg) for
// public key pk and signer identity uid. An empty uid selects
// DEFAULT_USER_ID. It fails only if uid is too long for createZA.
func ComputeSM3Digest(pk *PublicKey, uid string, msg []byte) ([]byte, error) {
	id := uid
	if id == "" {
		id = DEFAULT_USER_ID
	}
	za, err := createZA(pk, id)
	if err != nil {
		return nil, err
	}
	return createDigest(za, msg)
}

// GenerateKey creates a new SM2 key pair on the P256Sm2 curve using
// crypto/rand as the entropy source. It fails only if reading random bytes
// fails; there is no retry loop.
//
// The private scalar D is taken from BitSize/8+8 random bytes, reduced modulo
// N-2 and shifted by one, so D lies in [1, N-2]. D = N-1 is excluded because
// signing needs (1 + D) to be invertible modulo N. The 8 extra random bytes
// keep the bias introduced by the reduction negligible.
func GenerateKey() (*PrivateKey, error) {
	curve := P256Sm2()
	params := curve.Params()

	seed := make([]byte, params.BitSize/8+8)
	if _, err := io.ReadFull(rand.Reader, seed); err != nil {
		return nil, err
	}

	upper := new(big.Int).Sub(params.N, two)
	d := new(big.Int).SetBytes(seed)
	d.Mod(d, upper)
	d.Add(d, one)

	key := &PrivateKey{D: d}
	key.PublicKey.Curve = curve
	key.PublicKey.X, key.PublicKey.Y = curve.ScalarBaseMult(d.Bytes())
	return key, nil
}

//func GenerateKey() (*PrivateKey, error) {
//	var err error
//
//	sm2keygenargs := map[string]string{
//		"ec_paramgen_curve": "sm2p256v1",
//		"ec_param_enc":      "named_curve",
//	}
//
//	var sm2sk *gmssl.PrivateKey
//	for i := 0; i < 100; i++ {
//		sm2sk, err = gmssl.GeneratePrivateKey("EC", sm2keygenargs, nil)
//		if err == nil {
//			break
//		}
//	}
//	if err != nil {
//		return nil, err
//	}
//
//	skPEM, err := sm2sk.GetUnencryptedPEM()
//	if err != nil {
//		return nil, err
//	}
//
//	return PrivateKeyFromPEM(skPEM)
//}

// Sign produces an ASN.1 DER-encoded SM2 signature (SEQUENCE { R, S }) over
// msg. msg is treated as the precomputed digest e, with no hashing done here;
// see SignWithSM3 for signing a raw message.
func Sign(priv *PrivateKey, msg []byte) ([]byte, error) {
	r, s, err := SignwithoutSm3(priv, msg, nil)
	if err != nil {
		return nil, err
	}
	sig := sm2Signature{R: r, S: s}
	return asn1.Marshal(sig)
}

// SignwithoutSm3 computes a raw SM2 signature (r, s) over msg. msg must
// already be the digest e (for example from ComputeSM3Digest); it is read as
// a big-endian integer and not hashed here.
//
// uid is accepted for API compatibility but is not used: the signer identity
// is expected to be folded into msg already.
//
// Per attempt, with k random in [1, N-1] and (x1, _) = k*G:
//
//	r = (e + x1) mod N        retry if r == 0 or r + k == N
//	s = (1 + d)^-1 * (k - r*d) mod N   retry if s == 0
//
// It returns errZeroParam if the curve order is zero, an error if 1 + D is
// not invertible modulo N, or any error from the random source. On error,
// r and s are nil.
func SignwithoutSm3(priv *PrivateKey, msg, uid []byte) (r, s *big.Int, err error) {
	c := priv.Curve
	N := c.Params().N
	if N.Sign() == 0 {
		return nil, nil, errZeroParam
	}
	// (1 + d)^-1 mod N is the same for every attempt, so compute it once.
	// ModInverse returns nil when no inverse exists (d ≡ N-1), which
	// previously caused a nil-pointer panic in the multiplication below.
	d1Inv := new(big.Int).ModInverse(new(big.Int).Add(priv.D, one), N)
	if d1Inv == nil {
		return nil, nil, errors.New("SM2: private key D+1 is not invertible modulo N")
	}
	e := new(big.Int).SetBytes(msg)
	for {
		var k *big.Int
		k, err = randFieldElement(c, rand.Reader)
		if err != nil {
			return nil, nil, err
		}
		x1, _ := c.ScalarBaseMult(k.Bytes())
		r = new(big.Int).Add(x1, e)
		r.Mod(r, N)
		if r.Sign() == 0 || new(big.Int).Add(r, k).Cmp(N) == 0 {
			continue
		}
		s = new(big.Int).Mul(r, priv.D)
		s.Sub(k, s)
		s.Mul(s, d1Inv)
		s.Mod(s, N)
		if s.Sign() != 0 {
			return r, s, nil
		}
	}
}

// Verify reports whether sig, an ASN.1 DER-encoded SM2 signature, is valid
// for pub over msg. msg is the precomputed digest e, read as a big-endian
// integer and not hashed here. Any parse failure yields false.
//
// Checks performed: r, s in [1, N-1]; t = (r + s) mod N != 0; and, with
// (x1', _) = s*G + t*P, that (e + x1') mod N == r.
func Verify(pub *PublicKey, msg, sig []byte) bool {
	r, s, err := tjfoc.SignDataToSignDigit(sig)
	if err != nil {
		return false
	}

	curve := pub.Curve
	order := curve.Params().N
	inRange := func(v *big.Int) bool {
		return v.Sign() > 0 && v.Cmp(order) < 0
	}
	if !inRange(r) || !inRange(s) {
		return false
	}

	// SM2-specific step: t = (r + s) mod N must be non-zero.
	t := new(big.Int).Add(r, s)
	t.Mod(t, order)
	if t.Sign() == 0 {
		return false
	}

	sx, sy := curve.ScalarBaseMult(s.Bytes())
	tx, ty := curve.ScalarMult(pub.X, pub.Y, t.Bytes())
	x, _ := curve.Add(sx, sy, tx, ty)

	expected := new(big.Int).Add(x, new(big.Int).SetBytes(msg))
	expected.Mod(expected, order)
	return expected.Cmp(r) == 0
}

// SignWithSM3 signs the raw message msg. It computes e = SM3(Z_A || msg) for
// priv's public key and uid (empty selects DEFAULT_USER_ID), then signs e with
// Sign, returning the ASN.1 DER-encoded signature.
func SignWithSM3(priv *PrivateKey, msg []byte, uid string) ([]byte, error) {
	digest, err := ComputeSM3Digest(&priv.PublicKey, uid, msg)
	if err != nil {
		return nil, err
	}
	return Sign(priv, digest)
}

// randFieldElement returns a random scalar k in [1, N-1] for curve c.
//
// It reads BitSize/8+8 bytes from random (crypto/rand.Reader when random is
// nil). The 8 extra bytes keep the bias from reducing modulo N-1 negligible.
// If the read fails, k is nil and the read error is returned.
func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
	if random == nil {
		// Fall back to the system CSPRNG when no trusted source is supplied.
		random = rand.Reader
	}
	params := c.Params()
	buf := make([]byte, params.BitSize/8+8)
	if _, err = io.ReadFull(random, buf); err != nil {
		return nil, err
	}
	nMinusOne := new(big.Int).Sub(params.N, one)
	k = new(big.Int).SetBytes(buf)
	k.Mod(k, nMinusOne)
	k.Add(k, one)
	return k, nil
}

// VerifyWithSM3 reports whether sig is a valid SM2 signature by pub over the
// raw message msg. It computes e = SM3(Z_A || msg) for pub and uid, then
// checks sig against e with Verify.
//
// An empty uid selects DEFAULT_USER_ID, matching SignWithSM3 and
// ComputeSM3Digest. Without this, a signature produced by
// SignWithSM3(priv, msg, "") would never verify with uid "". A uid too long
// for createZA yields false.
func VerifyWithSM3(pub *PublicKey, msg, sig []byte, uid string) bool {
	digest, err := ComputeSM3Digest(pub, uid, msg)
	if err != nil {
		return false
	}
	return Verify(pub, digest, sig)
}

// Encrypt encrypts msg to pk with the tjfoc/gmsm SM2 implementation, drawing
// randomness from crypto/rand. The mode value 1 is forwarded to tjfoc.Encrypt,
// and Decrypt passes the same value.
// NOTE(review): confirm which ciphertext layout (C1C2C3 vs C1C3C2) tjfoc
// assigns to mode 1.
func Encrypt(pk *PublicKey, msg []byte) ([]byte, error) {
	return tjfoc.Encrypt(ParsePktoTJ(pk), msg, rand.Reader, 1)
}

// Decrypt decrypts an SM2 ciphertext produced by Encrypt, using the
// tjfoc/gmsm implementation with the same mode value (1).
func Decrypt(sk *PrivateKey, cipher []byte) ([]byte, error) {
	return tjfoc.Decrypt(ParsePrivtoTJ(sk), cipher, 1)
}

// GetSignatureFromCFCA normalizes an ASN.1 DER SM2 signature produced by
// CFCA into the minimal encoding Go emits.
//
// When an r or s component fits in 31 bytes, CFCA left-pads it with a zero
// to 32 bytes, whereas Go writes the 31-byte form with length 31. For
// compatibility, this removes such redundant leading zero bytes and shrinks
// both the element length and the outer length accordingly.
//
// Expected layout (single-byte lengths only):
//
//	tag(1) + totalLen(1) + [tag(1) + len(1) + data]...
//
// Rules:
//   - A leading zero is removed only when the next byte of the same element
//     has its high bit clear. The zero that marks a positive integer with
//     its top bit set (e.g. a 33-byte value) is kept.
//   - At least one content byte always remains in each element.
//   - Stripping never looks past the current element. The previous version
//     could read into the next element or past the end of the input and
//     panic.
//   - An element whose declared length runs past the end of the input is
//     copied through unchanged and ends processing.
//   - Inputs shorter than the two-byte header are returned unchanged.
//
// NOTE(review): a trailing fragment of one or two bytes, too short to hold
// tag + length + data, is dropped from the output.
func GetSignatureFromCFCA(signature []byte) []byte {
	const headerLen = 2
	dataLength := len(signature)
	if dataLength < headerLen {
		return signature
	}

	var out bytes.Buffer
	out.Write(signature[:headerLen])
	// totalLen is the outer length, reduced by every zero byte stripped.
	totalLen := signature[1]

	dataIndex := headerLen
	for dataIndex+2 < dataLength {
		out.WriteByte(signature[dataIndex]) // element tag
		dataCount := int(signature[dataIndex+1])
		if dataIndex+dataCount+2 > dataLength {
			// Truncated element: copy the remainder verbatim and stop.
			out.Write(signature[dataIndex+1:])
			break
		}

		start := dataIndex + 2
		zeroCount := 0
		for zeroCount+1 < dataCount &&
			signature[start+zeroCount] == 0 &&
			signature[start+zeroCount+1]&0x80 == 0 {
			zeroCount++
		}
		out.WriteByte(byte(dataCount - zeroCount))
		out.Write(signature[start+zeroCount : start+dataCount])
		totalLen -= byte(zeroCount)

		dataIndex += dataCount + 2
	}

	result := out.Bytes()
	result[1] = totalLen
	return result
}

// Public returns the public half of sk as a *PublicKey, as required by
// crypto.Signer. The returned pointer aliases sk's embedded PublicKey.
func (sk *PrivateKey) Public() gocrypto.PublicKey {
	var pub gocrypto.PublicKey = &sk.PublicKey
	return pub
}

// Sign implements crypto.Signer.
//
// When opts equals crypto.Hash(SM3) (SM3 as defined by the bctls package),
// digest is treated as the raw message: it is hashed with DEFAULT_USER_ID via
// SignWithSM3. For any other opts, including nil, digest is signed as-is via
// Sign. The rand argument is ignored; the signing path always draws
// randomness from crypto/rand.
func (sk *PrivateKey) Sign(rand io.Reader, digest []byte, opts gocrypto.SignerOpts) (signature []byte, err error) {
	if opts != nil && opts == gocrypto.Hash(crypto.SM3) {
		return SignWithSM3(sk, digest, DEFAULT_USER_ID)
	}
	return Sign(sk, digest)
}

// Verify checks sig against pk.
//
// When opts equals crypto.Hash(SM3) (SM3 as defined by the bctls package),
// digest is treated as the raw message and verified with DEFAULT_USER_ID via
// VerifyWithSM3. For any other opts, including nil, digest is verified as-is
// via Verify.
func (pk *PublicKey) Verify(digest []byte, sig []byte, opts gocrypto.SignerOpts) bool {
	if opts != nil && opts == gocrypto.Hash(crypto.SM3) {
		return VerifyWithSM3(pk, digest, sig, DEFAULT_USER_ID)
	}
	return Verify(pk, digest, sig)
}
