/*
Copyright IBM Corp. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/
/*
Notice: This file has been modified for Hyperledger Fabric SDK Go usage.
Please review third_party pinning scripts and patches for more details.
*/

package keyutil

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/x509"
	"encoding/asn1"
	"encoding/pem"
	"errors"
	"fmt"

	cspx509 "github.com/hyperledger/fabric-sdk-go/internal/github.com/hyperledger/fabric/bccsp/x509"
	sm2 "github.com/hyperledger/fabric-sdk-go/internal/github.com/hyperledger/fabric/bctls/tencentsm2"
	flogging "github.com/hyperledger/fabric-sdk-go/internal/github.com/hyperledger/fabric/sdkpatch/logbridge"
)

// logger is the package-level logger, obtained from the flogging bridge under
// the name "fabric-ca_keys".
var logger = flogging.MustGetLogger("fabric-ca_keys")

// ASN.1 object identifiers used by this package. The dotted form of each
// identifier is given alongside its declaration.
var (
	// Named-curve identifiers.
	oidNamedCurveP224 = asn1.ObjectIdentifier{1, 3, 132, 0, 33}            // 1.3.132.0.33
	oidNamedCurveP256 = asn1.ObjectIdentifier{1, 2, 840, 10045, 3, 1, 7}   // 1.2.840.10045.3.1.7
	oidNamedCurveP384 = asn1.ObjectIdentifier{1, 3, 132, 0, 34}            // 1.3.132.0.34
	oidNamedCurveP521 = asn1.ObjectIdentifier{1, 3, 132, 0, 35}            // 1.3.132.0.35
	oidNamedCurveSm2  = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301}    // 1.2.156.10197.1.301

	// Public-key algorithm identifier for ECDSA keys.
	oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} // 1.2.840.10045.2.1
)

// oidFromNamedCurve returns the ASN.1 object identifier for curve when it is
// one of elliptic.P224, elliptic.P256, elliptic.P384 or elliptic.P521.
//
// The boolean result reports whether the curve was recognized; for any other
// curve the identifier is nil and the boolean is false.
func oidFromNamedCurve(curve elliptic.Curve) (asn1.ObjectIdentifier, bool) {
	if curve == elliptic.P224() {
		return oidNamedCurveP224, true
	}
	if curve == elliptic.P256() {
		return oidNamedCurveP256, true
	}
	if curve == elliptic.P384() {
		return oidNamedCurveP384, true
	}
	if curve == elliptic.P521() {
		return oidNamedCurveP521, true
	}
	// Not one of the curves handled above.
	return nil, false
}

//func PrivateKeyToDER(privateKey *ecdsa.PrivateKey) ([]byte, error) {
//	if privateKey == nil {
//		return nil, errors.New("invalid ecdsa private key. It must be different from nil")
//	}
//
//	return x509.MarshalECPrivateKey(privateKey)
//}

// PrivateKeyToDER serializes a private key to DER.
//
// Supported key types are *sm2.PrivateKey (encoded with
// sm2.MarshalPrivateKey) and *ecdsa.PrivateKey (encoded with
// x509.MarshalECPrivateKey). Any other type yields an error.
//
// An error is also returned when privateKey is nil, including the case where
// it is a nil pointer of one of the supported types.
func PrivateKeyToDER(privateKey interface{}) ([]byte, error) {
	errNilKey := errors.New("invalid EC private key. It must be different from nil")
	if privateKey == nil {
		return nil, errNilKey
	}
	switch k := privateKey.(type) {
	case *sm2.PrivateKey:
		// An interface holding a typed nil pointer is itself non-nil, so the
		// check above does not catch it; reject it here instead of handing a
		// nil pointer to the marshaller.
		if k == nil {
			return nil, errNilKey
		}
		return sm2.MarshalPrivateKey(k)
	case *ecdsa.PrivateKey:
		// x509.MarshalECPrivateKey dereferences the key and would panic on a
		// nil pointer.
		if k == nil {
			return nil, errNilKey
		}
		return x509.MarshalECPrivateKey(k)
	default:
		return nil, errors.New("Failed to recognize the privateKey")
	}
}

//func DERToPrivateKey(der []byte) (key interface{}, err error) {
//
//	if key, err = x509.ParsePKCS1PrivateKey(der); err == nil {
//		return key, nil
//	}
//
//	if key, err = x509.ParsePKCS8PrivateKey(der); err == nil {
//		switch key.(type) {
//		case *ecdsa.PrivateKey:
//			return
//		default:
//			return nil, errors.New("found unknown private key type in PKCS#8 wrapping")
//		}
//	}
//
//	if key, err = x509.ParseECPrivateKey(der); err == nil {
//		return
//	}
//
//	return nil, errors.New("fabric ca :invalid key type. The DER must contain an ecdsa.PrivateKey")
//}

// DERToPrivateKey parses a DER-encoded private key by trying each supported
// encoding in turn and returning the first one that parses:
//
//  1. PKCS#1, via cspx509.ParsePKCS1PrivateKey.
//  2. PKCS#8, via x509.ParsePKCS8PrivateKey. Only *ecdsa.PrivateKey and
//     *sm2.PrivateKey payloads are accepted; a PKCS#8 structure wrapping any
//     other key type is rejected immediately, without trying the remaining
//     encodings.
//  3. EC private key, via x509.ParseECPrivateKey, accepted when its curve is
//     recognized by oidFromNamedCurve.
//  4. SM2, via sm2.UnmarshalPrivateKey.
//
// When every attempt fails, a generic error is returned and the individual
// parse errors are discarded. The returned key is nil whenever the error is
// non-nil.
func DERToPrivateKey(der []byte) (key interface{}, err error) {
	if pkcs1Key, pkcs1Err := cspx509.ParsePKCS1PrivateKey(der); pkcs1Err == nil {
		logger.Debug("PKCS1 type private key recognized")
		return pkcs1Key, nil
	}

	if pkcs8Key, pkcs8Err := x509.ParsePKCS8PrivateKey(der); pkcs8Err == nil {
		_, isECDSA := pkcs8Key.(*ecdsa.PrivateKey)
		_, isSM2 := pkcs8Key.(*sm2.PrivateKey)
		if !isECDSA && !isSM2 {
			return nil, errors.New("found unknown private key type in PKCS#8 wrapping")
		}
		logger.Debug("PKCS8 type private key recognized")
		return pkcs8Key, nil
	}

	if ecKey, ecErr := x509.ParseECPrivateKey(der); ecErr == nil {
		curveOID, known := oidFromNamedCurve(ecKey.Curve)
		// Keys on an unrecognized curve, or on the SM2 curve, are not returned
		// as ECDSA keys; they fall through to the SM2 parser below.
		if known && !curveOID.Equal(oidNamedCurveSm2) {
			logger.Debug("ECDSA private key recognized")
			return ecKey, nil
		}
	}

	// Last resort: the SM2-specific encoding.
	if sm2Key, sm2Err := sm2.UnmarshalPrivateKey(der); sm2Err == nil {
		logger.Debug("SM2 private key recognized")
		return sm2Key, nil
	}

	return nil, errors.New("invalid key type. The DER must contain an ecdsa.PrivateKey or sm2.PrivateKey")
}

//func PEMToPrivateKey(raw []byte, pwd []byte) (interface{}, error) {
//	block, _ := pem.Decode(raw)
//	if block == nil {
//		return nil, fmt.Errorf("failed decoding PEM. Block must be different from nil [% x]", raw)
//	}
//
//	// TODO: derive from header the type of the key
//
//	if x509.IsEncryptedPEMBlock(block) {
//		if len(pwd) == 0 {
//			return nil, errors.New("encrypted Key. Need a password")
//		}
//
//		decrypted, err := x509.DecryptPEMBlock(block, pwd)
//		if err != nil {
//			return nil, fmt.Errorf("failed PEM decryption: [%s]", err)
//		}
//
//		key, err := DERToPrivateKey(decrypted)
//		if err != nil {
//			return nil, err
//		}
//		return key, err
//	}
//
//	cert, err := DERToPrivateKey(block.Bytes)
//	if err != nil {
//		return nil, err
//	}
//	return cert, err
//}

// PEMToPrivateKey decodes the first PEM block in raw and parses its contents
// as a private key using DERToPrivateKey.
//
// If the block is encrypted (as reported by x509.IsEncryptedPEMBlock), pwd is
// used to decrypt it first; an empty pwd is an error in that case. For
// unencrypted blocks pwd is ignored. Any data following the first PEM block
// is ignored.
//
// An error is returned when raw is empty, when no PEM block can be decoded,
// when decryption fails, or when the DER payload is not a supported key. The
// returned key is nil whenever the error is non-nil.
func PEMToPrivateKey(raw []byte, pwd []byte) (interface{}, error) {
	if len(raw) == 0 {
		return nil, errors.New("Invalid PEM. It must be different from nil.")
	}
	block, _ := pem.Decode(raw)
	if block == nil {
		// The input is deliberately not echoed in the error: it is expected
		// to hold private-key material, and error values commonly end up in
		// logs.
		return nil, errors.New("failed decoding PEM. Block must be different from nil")
	}

	// TODO: derive from header the type of the key

	der := block.Bytes
	// NOTE(review): x509.IsEncryptedPEMBlock and x509.DecryptPEMBlock are
	// marked deprecated in newer Go releases; they are kept here to preserve
	// support for existing encrypted key files.
	if x509.IsEncryptedPEMBlock(block) {
		if len(pwd) == 0 {
			return nil, errors.New("encrypted Key. Need a password")
		}

		decrypted, err := x509.DecryptPEMBlock(block, pwd)
		if err != nil {
			return nil, fmt.Errorf("failed PEM decryption: [%s]", err)
		}
		der = decrypted
	}

	key, err := DERToPrivateKey(der)
	if err != nil {
		return nil, err
	}
	return key, nil
}
