/*
Copyright 2017 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"time"

	"github.com/maps-booking/api"
	"github.com/maps-booking/utils"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"

	mpb "github.com/maps-booking/bookingservice"
	fpb "github.com/maps-booking/feeds"
)

// logFile is the filename prefix for this client's log file; createLogFile
// appends a UTC timestamp to it.
const (
	logFile = "grpc_test_client_log_"
)

/*
Below are PEM-encoded copies of the test client certificate and its private key.
The certificate is not self-signed: it is issued by the maps-booking_test-client
test CA (compare the Issuer and Subject fields below). To use this cert,
maps-booking/certs/root.pem must be added to your server's trusted CAs.

The contents of this cert can be viewed in human-readable form with openssl.
The command and an excerpt of its output are below for your convenience.

	openssl x509 -in .../maps-booking/certs/test_client_cert.pem -inform pem -noout -text

	Certificate:
			Data:
					Version: 3 (0x2)
					Serial Number: 1 (0x1)
			Signature Algorithm: sha256WithRSAEncryption
					Issuer: C=US, ST=California, L=Mountain View, O=Internet Widgits Pty Ltd, CN=maps-booking_test-client
					Validity
							Not Before: Oct 27 01:53:28 2017 GMT
							Not After : Oct 25 01:53:28 2027 GMT
					Subject: C=US, ST=California, O=Internet Widgits Pty Ltd, CN=maps-booking_test-client-cert
*/
// clientCert is the PEM-encoded test client certificate described above. When
// the tls flag is set it is loaded with tls.X509KeyPair (paired with clientKey)
// and presented to the server.
const clientCert = `-----BEGIN CERTIFICATE-----
MIIE5DCCAsygAwIBAgIBATANBgkqhkiG9w0BAQsFADCBgDELMAkGA1UEBhMCVVMx
EzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDU1vdW50YWluIFZpZXcxITAf
BgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEhMB8GA1UEAwwYbWFwcy1i
b29raW5nX3Rlc3QtY2xpZW50MB4XDTE3MTAyNzAxNTMyOFoXDTI3MTAyNTAxNTMy
OFowbTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEmMCQGA1UEAwwdbWFwcy1ib29raW5n
X3Rlc3QtY2xpZW50LWNlcnQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB
AQC4SoqntW+syzjdw5an2HkP4ZobuUr/i9eTZ5ub/JpHnKHE/xpAFAI3Pii7b4mv
ZlbEazRsKJc0LQ+3hblBQdwBxF/QPElH10fD5uhScOOpLz5M9wlX+cDLSrXhDiKM
9xPY9OlpydIUu/FbU8V+gIsibruCOUB92aMKTkZMwgMRbPIBHF4IsLODm2JQiyFF
EbJje3mVmQtHyLoWCwWU75wtIgFt7WheTSEw2IpotZTT3nQCqhhrYvUvcwYnj7CD
/18n7ClA/7nf3+f+1WkZX6GITvXEe3A8Rj0oes/cZlzZO0aVxmosnAQiIt2LWuis
Lrs//lQEhAl/rRl1ZHoE+LOLAgMBAAGjezB5MAkGA1UdEwQCMAAwLAYJYIZIAYb4
QgENBB8WHU9wZW5TU0wgR2VuZXJhdGVkIENlcnRpZmljYXRlMB0GA1UdDgQWBBSK
90Bvc8jEeKuD4ZgUak0QCi5byjAfBgNVHSMEGDAWgBSwnp+00yr82MUrKAIkoVMT
WtzoozANBgkqhkiG9w0BAQsFAAOCAgEAtJoFY0TVZP4SwnLaX6hBEbv3CvIhL/08
WSce1e48XKZCWnL+NmxFJu102QSWjXwpF1XsCEOUjhEWsOfmLQ0/z5Pieqsj9cfx
vfH8Q8BZkC9shykQrO0Hg/Ip+VsU3KSpkpehj0cfUuPuMFZWmPouZGnpyTMXGx/C
L7SeqBBPa3TlSpONr5IVrI0j4Agfv0+qga7DfQUY6jsxx8uVxl6jbpFBx7UXekTa
aIP+Wqq0H9x/yWYt+K6yWmb5THAx/pTNOUYGFiFnW56fUap1UgPEF0vNZHHYZycD
hePsd9jlomD3G38tBmM5qd5CH3OH9S2U2CnC1ifDP6DxxlNHPFCxk6YrISu5fEly
+kcbf3PNKJlqtIL4/GRYsD1e0eVNTMEh5i+qgGnEoErVnlDTM+xolf1DJzSA3uNO
P6HccNGDcV4TePvQ9KU3ACMtBCNBbLqQm1r7e8U6MzDmAUOXhDvXMv7pdLFNiJvd
uD/faRnhAt9AsRCz1HjYx7aNz5jpB7uZCrDlDmuFGxG8ieInQMrATUz+E8Yfzh/d
5k5U0jOgR+IjPKpaq6Z4ejwg4i9vW7EJ8o+Nk1oMB0IMjAJi/UOZl9zRANvcVxO5
T8dmsxdWt/l68bdjP+jY1vjQ98ebTfo+q1cTNDoO0nDJHOw/NDemxTPgUa1Rs5vK
5c2Q8YXBTiU=
-----END CERTIFICATE-----`
// clientKey is the PEM-encoded RSA private key paired with clientCert in
// tls.X509KeyPair. It is a test-only credential embedded in source code, so it
// provides no secrecy and must never be used outside testing.
const clientKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAuEqKp7VvrMs43cOWp9h5D+GaG7lK/4vXk2ebm/yaR5yhxP8a
QBQCNz4ou2+Jr2ZWxGs0bCiXNC0Pt4W5QUHcAcRf0DxJR9dHw+boUnDjqS8+TPcJ
V/nAy0q14Q4ijPcT2PTpacnSFLvxW1PFfoCLIm67gjlAfdmjCk5GTMIDEWzyARxe
CLCzg5tiUIshRRGyY3t5lZkLR8i6FgsFlO+cLSIBbe1oXk0hMNiKaLWU0950AqoY
a2L1L3MGJ4+wg/9fJ+wpQP+539/n/tVpGV+hiE71xHtwPEY9KHrP3GZc2TtGlcZq
LJwEIiLdi1rorC67P/5UBIQJf60ZdWR6BPiziwIDAQABAoIBAQCbJy6ayTauzB0h
HxSMVMR/aVj8NEB+6rXgxN6OMdmVprnPB1KLVg0Tg0J5owrQ36D3FqZ41KeP5swP
nwZ7eT4HQtPDla3ATO9/b7xyA9a3Ti3uUCDOr1bwEAMV6XePJEjSZEbKqH40tJIb
aGih+wioQX+dwCOakIsiFwo6fzBkDtsj4V3qVunuZ5TRufUSrXebAMSsE5U+POch
FznfqsD/vhus5ggbhM/8wG1NgEtO5apkc+KjA/vnWpcpSeWxnBYmQZr0dU9iZJr+
vCxMKY+kdVfjohbeGQt51TqUkBdLbfUlun8fE1X9hrsA5aYonP82unJ+hWn3Wg60
fP9vr1gBAoGBAPWcsKoCG6DisgLAjNhAMC2Q7lp2qsM7Y01p4epLX8UxOHRMlpnr
erNRc4k75hZDp7/rtrO6eb3F+76mlPVbqi/JRRqY5xM8vsDF8KYSAW8XJXZ6AP+Z
G7sxjJ8QnlT1lAWK+XNxAEk6U3jj6eB+CfGSZJJh3tbjhX69BseE5uQBAoGBAMAV
6cpOq+F0ZSlh3sljvlDEPY0BaUtlAF5LOX6Vdg7c/zeWJtAAE24hPYwxSi9e4jYx
eHBM3Xk5yQVGsIbxl/dvgWGY2nbzb1YdbOEEzrVBh7hm8YW3DER+HwNcDpMhoWCf
O8uD59YNsyvrqD/vb7KpVDWr+yD8ClO7zjLL3ueLAoGAdAZMElOil5LfgptRLYrM
94mCf2uVaVqxo01EcnieyjlhMNdJQXbS5Miyan7IR3Y4VVpVWXvarMJNFRf+QBXI
RICwy0q1xgmpFsmqz9iror3tbZVeyV+bkQdsJWwlT38fKKspAda8ytrpua74uZrw
uZRtPBVNvneGhYNoI3Jt3AECgYB5bGDBdkHI3x8jra57eAXSYHrYK9A3zL0S3lKV
5j0e4CylItGeIq4lq/WQLYhLsZslztfnhW9rNlAQecMVSptZ2q7a1xkioHf849Tz
2Wohwi7dLpX2hOPIWEGaihLchyHQRlgyKkvfUAG2/dz5rY3aTpfg5bp1+1072Thb
e+yISQKBgGgtchLhw26f+wIvDDJx6AExEbbnV5V84riqZR4PpBz4e88xWdtpfYT7
J0nidjfbgcpfIysD8ELn/bGtNsYWrTkX8rqB5buBGf+WP7+ChWLwdm1GPtYJJN4s
ewhN35RpB6EjTdp9jp5ifTfSSBAoHN8yJeeHxU64HHb8nDdBwxmW
-----END RSA PRIVATE KEY-----`

// Command-line flags. Every test-flow flag defaults to false; set all_tests or
// the individual *_test flags to choose which endpoints are exercised.
var (
	serverAddr        = flag.String("server_addr", "example.com:80", "Your grpc server's address in the format of host:port")
	rpcTimeout        = flag.Duration("rpc_timeout", 30*time.Second, "How long to wait before abandoning a request, given as a duration such as 30s or 2m")
	testSlots         = flag.Int("num_test_slots", 10, "Maximum number of slots to test from availability_feed. Slots will be selected randomly")
	allFlows          = flag.Bool("all_tests", false, "Whether to test all endpoints.")
	healthFlow        = flag.Bool("health_check_test", false, "Whether to test the Health endpoint.")
	checkFlow         = flag.Bool("check_availability_test", false, "Whether to test the CheckAvailability endpoint.")
	bookFlow          = flag.Bool("booking_test", false, "Whether to test the CreateBooking endpoint.")
	listFlow          = flag.Bool("list_bookings_test", false, "Whether to test the ListBookings endpoint")
	statusFlow        = flag.Bool("booking_status_test", false, "Whether to test the GetBookingStatus endpoint.")
	rescheduleFlow    = flag.Bool("rescheduling_test", false, "Whether to test the UpdateBooking endpoint.")
	cancelAllBookings = flag.Bool("cancel_all_bookings", false, "This option assumes that the ListBookings and UpdateBooking endpoints are fully functional. This is a convenience flag for purging your system of all previously created bookings.")
	availabilityFeed  = flag.String("availability_feed", "", "Absolute path to availability feed required for all tests except health. Feeds can be in either json or pb3 format")
	outputDir         = flag.String("output_dir", "", "Absolute path of dir to dump log file.")
	enableTLS         = flag.Bool("tls", false, "Whether to enable TLS when using the test client. Please review the README.md before attempting to use this flag.")
	caFile            = flag.String("ca_file", "", "Absolute path to your server's Certificate Authority root cert. Downloading all roots currently recommended by the Google Internet Authority is a suitable alternative https://pki.google.com/roots.pem")
	serverName        = flag.String("servername_override", "", "Override FQDN to use. Please see README for additional details")
)

// counters tallies outcomes for a single run of the test client. Integer
// fields count individual RPC results; boolean fields record whether a flow
// that passes or fails as a whole succeeded. Fields are listed in the order
// the flows run.
type counters struct {
	TotalSlotsProcessed int

	HealthCheckSuccess bool

	CheckAvailabilitySuccess, CheckAvailabilityErrors int
	CreateBookingSuccess, CreateBookingErrors         int

	ListBookingsSuccess bool

	GetBookingStatusSuccess, GetBookingStatusErrors int
	CancelBookingsSuccess, CancelBookingsErrors     int

	ReschedulingSuccess bool
}

// GenerateBookings creates bookings from an availability feed.
//
// Each slot in av is visited in order. If CheckAvailability succeeds, the slot
// is booked with CreateBooking. A slot that fails either call is logged and
// skipped. Every outcome is tallied in stats, and the bookings that were
// created are returned in slot order. All calls share ctx, including any
// deadline it carries.
func GenerateBookings(ctx context.Context, av []*fpb.Availability, stats *counters, c mpb.BookingServiceClient) api.Bookings {
	log.Println("no previous bookings to use, acquiring new inventory")
	utils.LogFlow("Generate Fresh Inventory", "Start")
	defer utils.LogFlow("Generate Fresh Inventory", "End")

	n := len(av)
	// skip logs why slot idx was abandoned and bumps the matching error tally.
	skip := func(idx int, err error, errCount *int) {
		log.Printf("%s. skipping slot %d/%d", err.Error(), idx, n)
		*errCount++
	}

	var created api.Bookings
	for idx, slot := range av {
		if err := api.CheckAvailability(ctx, slot, c); err != nil {
			skip(idx, err, &stats.CheckAvailabilityErrors)
			continue
		}
		stats.CheckAvailabilitySuccess++

		bk, err := api.CreateBooking(ctx, slot, c)
		if err != nil {
			skip(idx, err, &stats.CreateBookingErrors)
			continue
		}
		created = append(created, bk)
		stats.CreateBookingSuccess++
	}
	return created
}

// createLogFile creates (or truncates) the log file for this run.
//
// The file lives in the output_dir flag's directory, or in the current working
// directory when that flag is empty. Its name is logFile followed by the
// current UTC time as YYYY-MM-DD_hh-mm-ss.
func createLogFile() (*os.File, error) {
	dir := *outputDir
	if len(dir) == 0 {
		wd, err := os.Getwd()
		if err != nil {
			return nil, err
		}
		dir = wd
	}

	ts := time.Now().UTC()
	stamp := fmt.Sprintf("%d-%02d-%02d_%02d-%02d-%02d",
		ts.Year(), ts.Month(), ts.Day(), ts.Hour(), ts.Minute(), ts.Second())
	return os.Create(filepath.Join(dir, logFile+stamp))
}

// setTimeout returns a child of ctx whose deadline is the rpc_timeout flag
// duration from now.
//
// The CancelFunc from context.WithTimeout is intentionally dropped because
// this signature returns only the context. Per the context package docs, the
// child's resources are then held until its deadline fires (or ctx is
// cancelled) instead of being released as soon as the RPC returns. go vet's
// lostcancel analyzer reports this pattern.
func setTimeout(ctx context.Context) (child context.Context) {
	child, _ = context.WithTimeout(ctx, *rpcTimeout)
	return child
}

// logStats writes a summary of the requested flows to the log and then ends
// the process.
//
// Only flows selected on the command line (or all of them when all_tests is
// set) are reported. CancelBooking results are reported whenever at least one
// cancellation was attempted, and cancellation failures count toward the
// error total. Previously they were collected but never shown.
//
// logStats never returns. It exits with status 0 when no errors were found
// and status 1 otherwise, so scripts and CI jobs can detect failures from the
// exit code.
func logStats(stats counters) {
	// log.Print instead of log.Println: these banners carry their own trailing
	// newlines, which go vet's printf check rejects in Println calls. The
	// output is byte-for-byte what the Println versions produced.
	log.Print("\n************* Begin Stats *************\n\n")
	var totalErrors int
	if *healthFlow || *allFlows {
		if stats.HealthCheckSuccess {
			log.Println("HealthCheck Succeeded")
		} else {
			totalErrors++
			log.Println("HealthCheck Failed")
		}
	}
	if *checkFlow || *allFlows {
		totalErrors += stats.CheckAvailabilityErrors
		log.Printf("CheckAvailability Errors: %d/%d", stats.CheckAvailabilityErrors, stats.CheckAvailabilityErrors+stats.CheckAvailabilitySuccess)
	}
	if *bookFlow || *allFlows {
		totalErrors += stats.CreateBookingErrors
		log.Printf("CreateBooking Errors: %d/%d", stats.CreateBookingErrors, stats.CreateBookingErrors+stats.CreateBookingSuccess)
	}
	if *listFlow || *allFlows {
		if stats.ListBookingsSuccess {
			log.Println("ListBookings Succeeded")
		} else {
			totalErrors++
			log.Println("ListBookings Failed")
		}
	}
	if *statusFlow || *allFlows {
		totalErrors += stats.GetBookingStatusErrors
		log.Printf("GetBookingStatus Errors: %d/%d", stats.GetBookingStatusErrors, stats.GetBookingStatusErrors+stats.GetBookingStatusSuccess)
	}
	// Cancellation is not tied to a single flow flag, so report it whenever any
	// cancellation was attempted. This is also the only result in
	// cancel_all_bookings mode.
	if attempted := stats.CancelBookingsErrors + stats.CancelBookingsSuccess; attempted > 0 {
		totalErrors += stats.CancelBookingsErrors
		log.Printf("CancelBooking Errors: %d/%d", stats.CancelBookingsErrors, attempted)
	}
	if *rescheduleFlow || *allFlows {
		if stats.ReschedulingSuccess {
			log.Println("Rescheduling Succeeded")
		} else {
			totalErrors++
			log.Println("Rescheduling Failed")
		}
	}

	log.Print("\n\n\n\n")
	if totalErrors == 0 {
		log.Println("All Tests Pass!")
	} else {
		log.Printf("Found %d Errors", totalErrors)
	}

	log.Print("\n************* End Stats *************\n\n")
	if totalErrors > 0 {
		os.Exit(1)
	}
	os.Exit(0)
}

// buildGrpcConnection dials the address given by the server_addr flag.
//
// Without the tls flag, the connection uses grpc.WithInsecure (no transport
// security). With it, the client presents the embedded test certificate
// (clientCert/clientKey) and trusts only the root certificates read from the
// ca_file flag. When servername_override is set, that name is used as
// tls.Config.ServerName when verifying the server's certificate.
//
// Returns an error if the embedded key pair cannot be loaded, if ca_file is
// unset, unreadable or contains no PEM certificates, or if grpc.Dial fails.
func buildGrpcConnection() (*grpc.ClientConn, error) {
	if !*enableTLS {
		return grpc.Dial(*serverAddr, grpc.WithInsecure())
	}

	cert, err := tls.X509KeyPair([]byte(clientCert), []byte(clientKey))
	if err != nil {
		return nil, fmt.Errorf("failed to load embedded client certificate: %v", err)
	}
	// Without this check an unset flag surfaces as an unhelpful
	// "open : no such file or directory" from ReadFile.
	if *caFile == "" {
		return nil, errors.New("the tls flag requires ca_file to be set to your server's root certificates file")
	}
	b, err := ioutil.ReadFile(*caFile)
	if err != nil {
		return nil, fmt.Errorf("failed to read root certificates file: %v", err)
	}
	cp := x509.NewCertPool()
	if !cp.AppendCertsFromPEM(b) {
		return nil, errors.New("failed to parse root certificates, please check your roots file (ca_file flag) and try again")
	}
	config := &tls.Config{Certificates: []tls.Certificate{cert}, RootCAs: cp}
	if *serverName != "" {
		config.ServerName = *serverName
	}
	return grpc.Dial(*serverAddr, grpc.WithTransportCredentials(credentials.NewTLS(config)))
}

// main runs the maps-booking gRPC test client.
//
// It parses flags, sends all log output to a timestamped file from
// createLogFile, and dials the server. It then runs the requested flows in a
// fixed order: HealthCheck, CheckAvailability, CreateBooking, ListBookings,
// GetBookingStatus, CancelBooking and Rescheduling. Results accumulate in
// stats, and logStats writes the summary and ends the process.
//
// Every path ends in logStats or a log.Fatal* call, and both exit via
// os.Exit. As a result, the deferred Close calls below do not run in practice.
func main() {
	flag.Parse()
	var stats counters

	// Set up logging before continuing with flows
	f, err := createLogFile()
	if err != nil {
		log.Fatalf("Failed to create log file %v", err)
	}
	defer f.Close()
	// From here on, everything logged (including log.Fatal messages) goes to
	// the file rather than stderr.
	log.SetOutput(f)

	conn, err := buildGrpcConnection()
	if err != nil {
		log.Fatalf("failed to dial: %v", err)
	}
	defer conn.Close()
	client := mpb.NewBookingServiceClient(conn)

	ctx := context.Background()

	// Health check doesn't affect the cancel booking flow so we let it through.
	if *cancelAllBookings && (*allFlows || *checkFlow || *bookFlow || *listFlow || *statusFlow || *rescheduleFlow) {
		log.Fatal("cancel_all_bookings is not supported with other test flows")
	}

	// HealthCheck Flow
	if *healthFlow || *allFlows {
		stats.HealthCheckSuccess = true
		if err = api.HealthCheck(setTimeout(ctx), conn); err != nil {
			stats.HealthCheckSuccess = false
			log.Println(err.Error())
		}
		// Stop here only when the health check was the sole request.
		// cancel_all_bookings is allowed alongside the health check (see the
		// check above), so it must not be cut short here.
		if !*allFlows && !*checkFlow && !*bookFlow &&
			!*listFlow && !*statusFlow && !*rescheduleFlow && !*cancelAllBookings {
			logStats(stats)
		}
	}

	// Every remaining flow except cancel_all_bookings needs slots from the
	// availability feed.
	var av []*fpb.Availability
	if !*cancelAllBookings {
		// Build availability records.
		if *availabilityFeed == "" {
			log.Fatal("please set availability_feed flag if you wish to test additional flows")
		}
		av, err = utils.AvailabilityFrom(*availabilityFeed, *testSlots)
		if err != nil {
			log.Fatal(err.Error())
		}
		stats.TotalSlotsProcessed += len(av)
	}

	// AvailabilityCheck Flow
	if *checkFlow || *allFlows {
		utils.LogFlow("Availability Check", "Start")
		totalSlots := len(av)

		// Filter av in place so later flows only use slots that passed the
		// availability check.
		j := 0
		for i, a := range av {
			if err = api.CheckAvailability(setTimeout(ctx), a, client); err != nil {
				log.Printf("%s. skipping slot %d/%d", err.Error(), i, totalSlots)
				stats.CheckAvailabilityErrors++
				continue
			}
			stats.CheckAvailabilitySuccess++
			av[j] = a
			j++
		}
		av = av[:j]
		utils.LogFlow("Availability Check", "End")
	}

	// CreateBooking Flow.
	var b []*mpb.Booking
	if *bookFlow || *allFlows {
		utils.LogFlow("Booking", "Start")
		totalSlots := len(av)
		for i, a := range av {
			booking, err := api.CreateBooking(setTimeout(ctx), a, client)
			if err != nil {
				log.Printf("%s. skipping slot %d/%d", err.Error(), i, totalSlots)
				stats.CreateBookingErrors++
				continue
			}
			b = append(b, booking)
			stats.CreateBookingSuccess++
		}
		utils.LogFlow("Booking", "End")
	}

	// ListBookings Flow. In cancel_all_bookings mode b is passed in empty;
	// presumably api.ListBookings then returns every existing booking, which
	// the CancelBooking flow below removes. TODO confirm against api.ListBookings.
	if *listFlow || *allFlows || *cancelAllBookings {
		if len(b) == 0 && !*cancelAllBookings {
			// NOTE(review): setTimeout(ctx) gives every RPC that
			// GenerateBookings makes one shared rpc_timeout deadline, so with
			// many slots the later calls may fail with deadline exceeded.
			b = GenerateBookings(setTimeout(ctx), av, &stats, client)
		}
		utils.LogFlow("List Bookings", "Start")
		b, err = api.ListBookings(setTimeout(ctx), b, client)
		stats.ListBookingsSuccess = true
		if err != nil {
			stats.ListBookingsSuccess = false
			log.Println(err.Error())
		}
		utils.LogFlow("List Bookings", "End")
	}

	// GetBookingStatus Flow
	if *statusFlow || *allFlows {
		if len(b) == 0 {
			// NOTE(review): same shared-deadline caveat as the call above.
			b = GenerateBookings(setTimeout(ctx), av, &stats, client)
		}

		utils.LogFlow("BookingStatus", "Start")
		totalBookings := len(b)

		// Filter b in place. Bookings whose status check fails are dropped
		// and therefore are not cancelled below.
		j := 0
		for i, booking := range b {
			if err = api.GetBookingStatus(setTimeout(ctx), booking, client); err != nil {
				log.Printf("%s. abandoning booking %d/%d", err.Error(), i, totalBookings)
				stats.GetBookingStatusErrors++
				continue
			}
			stats.GetBookingStatusSuccess++
			b[j] = booking
			j++
		}
		b = b[:j]
		utils.LogFlow("BookingStatus", "End")
	}

	// CancelBooking Flow. This runs whenever b holds bookings, regardless of
	// flags, so bookings created above are cleaned up.
	if len(b) > 0 {
		utils.LogFlow("Cancel Booking", "Start")
		for i, booking := range b {
			if err = api.CancelBooking(setTimeout(ctx), booking, client); err != nil {
				log.Printf("%s. abandoning booking %d/%d", err.Error(), i, len(b))
				stats.CancelBookingsErrors++
				continue
			}
			stats.CancelBookingsSuccess++
		}
		utils.LogFlow("Cancel Booking", "End")
	}

	// Rescheduling is nuanced and can be isolated
	// from the rest of the tests. It uses av, which the availability check
	// flow may have filtered.
	if *rescheduleFlow || *allFlows {
		utils.LogFlow("Rescheduling", "Start")
		stats.ReschedulingSuccess = true
		if err = api.Rescheduling(setTimeout(ctx), av, client); err != nil {
			log.Println(err.Error())
			stats.ReschedulingSuccess = false
		}
		utils.LogFlow("Rescheduling", "End")
	}

	logStats(stats)
}
