//SPDX-License-Identifier: Apache-2.0
package gmcredentials

import (
	"github.com/hyperledger/fabric-sdk-go/internal/github.com/hyperledger/fabric/bctls/tls"
	flogging "github.com/hyperledger/fabric-sdk-go/internal/github.com/hyperledger/fabric/sdkpatch/logbridge"
	"golang.org/x/net/context"
	"google.golang.org/grpc/credentials"
	"net"
)

// alpnProtoStr lists the application-level protocols advertised via ALPN
// for gRPC; "h2" is the HTTP/2 protocol identifier.
var alpnProtoStr = []string{"h2"}

// gmcredsLogger is this package's logger, registered as "gmcreds.tls".
var gmcredsLogger = flogging.MustGetLogger("gmcreds.tls")

// AuthInfo describes the authentication information attached to a
// connection after a successful handshake.
//
// AuthType returns a short name identifying the authentication mechanism
// (for example, TLSInfo in this package reports "tls").
type AuthInfo interface {
	AuthType() string
}

// ProtocolInfo provides information regarding the gRPC wire protocol version,
// security protocol, security protocol version in use, server name, etc.
//
// Its fields mirror grpc's credentials.ProtocolInfo. Within this file,
// tlsCred.Info returns credentials.ProtocolInfo rather than this type.
type ProtocolInfo struct {
	// ProtocolVersion is the gRPC wire protocol version.
	ProtocolVersion string
	// SecurityProtocol is the security protocol in use.
	SecurityProtocol string
	// SecurityVersion is the security protocol version.  It is a static version string from the
	// credentials, not a value that reflects per-connection protocol negotiation.  To retrieve
	// details about the credentials used for a connection, use the Peer's AuthInfo field instead.
	//
	// Deprecated: please use Peer.AuthInfo.
	SecurityVersion string
	// ServerName is the user-configured server name.
	ServerName string
}

// TLSInfo is the auth information returned by a successful client or server
// handshake. State holds the connection state of that TLS connection.
type TLSInfo struct {
	State tls.ConnectionState
}

// AuthType reports the authentication mechanism name, which is always "tls".
func (TLSInfo) AuthType() string {
	const authType = "tls"
	return authType
}

// tlsCred implements gRPC transport credentials on top of the tls package
// imported by this file. NewTLS populates it with the result of
// cloneTLSConfig applied to the caller's config.
type tlsCred struct {
	config *tls.Config // base TLS configuration for client and server handshakes
}

// Info reports the protocol information for these credentials. The security
// protocol is "tls", the security version is the fixed string "1.2", and the
// server name is taken from the credentials' config.
func (c *tlsCred) Info() credentials.ProtocolInfo {
	var info credentials.ProtocolInfo
	info.SecurityProtocol = "tls"
	info.SecurityVersion = "1.2"
	info.ServerName = c.config.ServerName
	return info
}

// ClientHandshake performs the client side of the TLS handshake over rawConn.
//
// If the configured ServerName is empty, it is derived from addr: the host
// part when addr is "host:port", otherwise addr verbatim. The handshake runs
// in a separate goroutine so that cancellation of ctx can abort the wait.
// On cancellation the connection is closed and ctx.Err() is returned. On a
// handshake error the connection is closed and that error is returned
// unwrapped. On success it returns the TLS connection and a TLSInfo holding
// its connection state.
func (c *tlsCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	// Work on a clone so the ServerName default below applies to this
	// connection only (assuming cloneTLSConfig returns an independent copy).
	cfg := cloneTLSConfig(c.config)
	if cfg.ServerName == "" {
		serverName, _, err := net.SplitHostPort(addr)
		if err != nil {
			// If the authority had no host port or cannot be parsed, use it as-is.
			serverName = addr
		}
		cfg.ServerName = serverName
	}
	conn := tls.Client(rawConn, cfg)
	gmcredsLogger.Debugf("remote address: %s", conn.RemoteAddr().String())

	// Buffered so the handshake goroutine can always deliver its single
	// result and exit, even after this function returned because ctx was done.
	errCh := make(chan error, 1)
	go func() {
		errCh <- conn.Handshake()
	}()

	select {
	case err := <-errCh:
		if err != nil {
			conn.Close()
			return nil, nil, err
		}
	case <-ctx.Done():
		// Closing the connection should also unblock the pending Handshake.
		conn.Close()
		return nil, nil, ctx.Err()
	}
	return conn, TLSInfo{conn.ConnectionState()}, nil
}

// ServerHandshake performs the server side of the TLS handshake over rawConn
// using the credentials' config.
//
// On failure it logs the error, closes the TLS connection and returns the
// handshake error unwrapped. On success it returns the TLS connection and a
// TLSInfo holding its connection state.
func (c *tlsCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	conn := tls.Server(rawConn, c.config)
	gmcredsLogger.Debugf("remote address: %s", conn.RemoteAddr().String())
	if err := conn.Handshake(); err != nil {
		gmcredsLogger.Errorf("Server TLS handshake failed with error %s", err)
		// Release the connection on failure, as the client-side handshake does.
		conn.Close()
		return nil, nil, err
	}
	return conn, TLSInfo{conn.ConnectionState()}, nil
}

// OverrideServerName replaces the ServerName stored in the credentials'
// config and always returns nil.
//
// The new name is what Info reports. ServerHandshake uses the config as-is,
// and ClientHandshake uses the new name as the server name to verify.
func (c *tlsCred) OverrideServerName(serverNameOverride string) error {
	cfg := c.config
	cfg.ServerName = serverNameOverride
	return nil
}

// NewTLS builds gRPC transport credentials from c.
//
// The config is copied with cloneTLSConfig, and the copy's NextProtos is set
// to alpnProtoStr, overwriting whatever protocols the copy inherited from c.
func NewTLS(c *tls.Config) credentials.TransportCredentials {
	cfg := cloneTLSConfig(c)
	cfg.NextProtos = alpnProtoStr
	return &tlsCred{config: cfg}
}

// Clone returns new credentials built by passing this value's config through
// NewTLS, which stores its own cloneTLSConfig copy.
func (c *tlsCred) Clone() credentials.TransportCredentials {
	base := c.config
	return NewTLS(base)
}
