/*
Copyright 2017 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main

import (
	"flag"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"time"

	"github.com/maps-booking/api"
	"github.com/maps-booking/utils"
	"golang.org/x/net/context"
	"google.golang.org/grpc"

	mpb "github.com/maps-booking/bookingservice"
	fpb "github.com/maps-booking/feeds"
)

// logFile is the filename prefix of the per-run log; createLogFile appends a
// UTC timestamp to it.
const (
	logFile = "grpc_test_client_log_"
)

// Command-line flags. Every test flow is opt-in: pass its flag or -all_tests.
// All flows except the health check also need -availability_feed.
var (
	serverAddr       = flag.String("server_addr", "example.com:80", "Your grpc server's address in the format of host:port")
	rpcTimeout       = flag.Duration("rpc_timeout", 30*time.Second, "Time to wait before abandoning a request, as a Go duration such as 30s or 1m")
	testSlots        = flag.Int("num_test_slots", 10, "Maximum number of slots to test from availability_feed. Slots will be selected randomly")
	allFlows         = flag.Bool("all_tests", false, "Whether to test all endpoints.")
	healthFlow       = flag.Bool("health_check_test", false, "Whether to test the Health endpoint.")
	checkFlow        = flag.Bool("check_availability_test", false, "Whether to test the CheckAvailability endpoint.")
	bookFlow         = flag.Bool("booking_test", false, "Whether to test the CreateBooking endpoint.")
	listFlow         = flag.Bool("list_bookings_test", false, "Whether to test the ListBookings endpoint")
	statusFlow       = flag.Bool("booking_status_test", false, "Whether to test the GetBookingStatus endpoint.")
	rescheduleFlow   = flag.Bool("rescheduling_test", false, "Whether to test the UpdateBooking endpoint.")
	availabilityFeed = flag.String("availability_feed", "", "Absolute path to availability feed required for all tests except health. Feeds can be in either json or pb3 format")
	outputDir        = flag.String("output_dir", "", "Absolute path of dir to dump log file.")
)

// counters tallies the outcome of every flow run by main. Integer fields count
// individual RPC results; bool fields record the pass/fail result of flows
// that are judged as a whole.
type counters struct {
	// Incremented by main with the number of slots loaded from the feed.
	TotalSlotsProcessed int

	// HealthCheck flow.
	HealthCheckSuccess bool

	// CheckAvailability RPCs, from the availability-check flow and from
	// GenerateBookings.
	CheckAvailabilitySuccess int
	CheckAvailabilityErrors  int

	// CreateBooking RPCs, from the booking flow and from GenerateBookings.
	CreateBookingSuccess int
	CreateBookingErrors  int

	// ListBookings flow.
	ListBookingsSuccess bool

	// GetBookingStatus RPCs.
	GetBookingStatusSuccess int
	GetBookingStatusErrors  int

	// CancelBooking RPCs, issued for every booking main still holds after the
	// status flow.
	CancelBookingsSuccess int
	CancelBookingsErrors  int

	// Rescheduling flow.
	ReschedulingSuccess bool
}

// GenerateBookings creates bookings from an availability feed.
//
// Each slot in av is first checked with CheckAvailability and, if that
// succeeds, booked with CreateBooking. A slot whose check or booking fails is
// logged and skipped. Every RPC outcome is tallied in stats. The same ctx
// (and therefore the same deadline, if any) is used for every RPC in the loop.
// Returns the bookings that were created, in slot order; the result is nil
// when none were created.
func GenerateBookings(ctx context.Context, av []*fpb.Availability, stats *counters, c mpb.BookingServiceClient) api.Bookings {
	log.Println("no previous bookings to use, acquiring new inventory")
	utils.LogFlow("Generate Fresh Inventory", "Start")
	defer utils.LogFlow("Generate Fresh Inventory", "End")

	slotCount := len(av)
	logSkip := func(cause error, idx int) {
		log.Printf("%s. skipping slot %d/%d", cause.Error(), idx, slotCount)
	}

	var created api.Bookings
	for idx, slot := range av {
		checkErr := api.CheckAvailability(ctx, slot, c)
		if checkErr != nil {
			logSkip(checkErr, idx)
			stats.CheckAvailabilityErrors++
			continue
		}
		stats.CheckAvailabilitySuccess++

		booking, bookErr := api.CreateBooking(ctx, slot, c)
		if bookErr != nil {
			logSkip(bookErr, idx)
			stats.CreateBookingErrors++
			continue
		}
		created = append(created, booking)
		stats.CreateBookingSuccess++
	}
	return created
}

// createLogFile creates the log file for this run and returns it open for
// writing; the caller must close it.
//
// The file is placed in -output_dir, or in the current working directory when
// that flag is empty. Its name is logFile followed by the current UTC time
// formatted as YYYY-MM-DD_hh-mm-ss. os.Create truncates an existing file of
// the same name.
func createLogFile() (*os.File, error) {
	dir := *outputDir
	if dir == "" {
		wd, err := os.Getwd()
		if err != nil {
			return nil, fmt.Errorf("resolving working directory for log file: %v", err)
		}
		dir = wd
	}

	// Reference-time layout producing the same zero-padded stamp as
	// "%d-%02d-%02d_%02d-%02d-%02d" over year, month, day, hour, minute, second.
	stamp := time.Now().UTC().Format("2006-01-02_15-04-05")
	return os.Create(filepath.Join(dir, logFile+stamp))
}

// setTimeout returns a child of ctx whose deadline is *rpcTimeout from now.
//
// NOTE(review): the CancelFunc from context.WithDeadline is discarded (go vet's
// lostcancel check reports this). The child context's resources are released
// only when its deadline passes or ctx is canceled. Callers that need earlier
// release should use context.WithTimeout directly and defer the cancel.
func setTimeout(ctx context.Context) context.Context {
	deadline := time.Now().Add(*rpcTimeout)
	expiring, _ := context.WithDeadline(ctx, deadline)
	return expiring
}

// logStats writes a summary of the selected flows to the standard logger and
// then terminates the process with os.Exit(0).
//
// A flow is reported only when its own flag or -all_tests is set.
// CancelBooking results are reported whenever at least one cancellation was
// attempted, because main runs that flow whenever it still holds bookings,
// regardless of flags. Every failure counts toward the error total printed
// at the end.
//
// Because this calls os.Exit, deferred calls in the caller do not run.
// NOTE(review): the exit status is 0 even when errors were found; consider a
// non-zero status if this client runs under CI.
func logStats(stats counters) {
	// log.Print with an explicit trailing "\n" gives the same output as the
	// former Println calls without vet's "redundant newline" warning.
	log.Print("\n************* Begin Stats *************\n\n")

	var totalErrors int
	selected := func(flowFlag *bool) bool { return *flowFlag || *allFlows }
	reportOutcome := func(name string, ok bool) {
		if ok {
			log.Printf("%s Succeeded", name)
			return
		}
		totalErrors++
		log.Printf("%s Failed", name)
	}
	reportErrors := func(name string, errs, successes int) {
		totalErrors += errs
		log.Printf("%s Errors: %d/%d", name, errs, errs+successes)
	}

	if selected(healthFlow) {
		reportOutcome("HealthCheck", stats.HealthCheckSuccess)
	}
	if selected(checkFlow) {
		reportErrors("CheckAvailability", stats.CheckAvailabilityErrors, stats.CheckAvailabilitySuccess)
	}
	if selected(bookFlow) {
		reportErrors("CreateBooking", stats.CreateBookingErrors, stats.CreateBookingSuccess)
	}
	if selected(listFlow) {
		reportOutcome("ListBookings", stats.ListBookingsSuccess)
	}
	if selected(statusFlow) {
		reportErrors("GetBookingStatus", stats.GetBookingStatusErrors, stats.GetBookingStatusSuccess)
	}
	// Cancellation is not gated by a flag in main, so report it whenever it ran;
	// otherwise failed cancellations would be silently ignored.
	if stats.CancelBookingsErrors+stats.CancelBookingsSuccess > 0 {
		reportErrors("CancelBooking", stats.CancelBookingsErrors, stats.CancelBookingsSuccess)
	}
	if selected(rescheduleFlow) {
		reportOutcome("Rescheduling", stats.ReschedulingSuccess)
	}

	log.Print("\n\n\n\n")
	if totalErrors == 0 {
		log.Println("All Tests Pass!")
	} else {
		log.Printf("Found %d Errors", totalErrors)
	}

	log.Print("\n************* End Stats *************\n\n")
	os.Exit(0)
}

// main runs the booking-server test flows selected by flags and writes all
// output to a timestamped log file (see createLogFile).
//
// Flows run in this order, each only when its flag or -all_tests is set:
// HealthCheck, CheckAvailability, CreateBooking, ListBookings and
// GetBookingStatus. Next, every booking still held is cancelled, regardless
// of flags. Rescheduling runs last. logStats then reports the results and
// exits the process.
func main() {
	flag.Parse()
	var stats counters

	// Set up logging before continuing with flows
	f, err := createLogFile()
	if err != nil {
		log.Fatalf("Failed to create log file %v", err)
	}
	defer f.Close()
	log.SetOutput(f)

	conn, err := grpc.Dial(*serverAddr, grpc.WithInsecure())
	if err != nil {
		log.Fatalf("fail to dial: %v", err)
	}
	defer conn.Close()
	client := mpb.NewBookingServiceClient(conn)

	ctx := context.Background()

	// HealthCheck Flow
	if *healthFlow || *allFlows {
		err = api.HealthCheck(setTimeout(ctx), conn)
		stats.HealthCheckSuccess = err == nil
		if err != nil {
			log.Println(err.Error())
		}
		// A health-only run needs no availability feed; logStats exits the
		// process here.
		if !*allFlows && !*checkFlow && !*bookFlow &&
			!*listFlow && !*statusFlow && !*rescheduleFlow {
			logStats(stats)
		}
	}

	// Build availability records.
	if *availabilityFeed == "" {
		log.Fatal("please set availability_feed flag if you wish to test additional flows")
	}
	av, err := utils.AvailabilityFrom(*availabilityFeed, *testSlots)
	if err != nil {
		log.Fatal(err.Error())
	}
	stats.TotalSlotsProcessed += len(av)

	// AvailabilityCheck Flow: filter av in place, keeping only the slots that
	// passed CheckAvailability.
	if *checkFlow || *allFlows {
		utils.LogFlow("Availability Check", "Start")
		totalSlots := len(av)

		j := 0
		for i, a := range av {
			if err = api.CheckAvailability(setTimeout(ctx), a, client); err != nil {
				log.Printf("%s. skipping slot %d/%d", err.Error(), i, totalSlots)
				stats.CheckAvailabilityErrors++
				continue
			}
			stats.CheckAvailabilitySuccess++
			av[j] = a
			j++
		}
		av = av[:j]
		utils.LogFlow("Availability Check", "End")
	}

	// CreateBooking Flow.
	var b []*mpb.Booking
	if *bookFlow || *allFlows {
		utils.LogFlow("Booking", "Start")
		totalSlots := len(av)
		for i, a := range av {
			booking, err := api.CreateBooking(setTimeout(ctx), a, client)
			if err != nil {
				log.Printf("%s. skipping slot %d/%d", err.Error(), i, totalSlots)
				stats.CreateBookingErrors++
				continue
			}
			b = append(b, booking)
			stats.CreateBookingSuccess++
		}
		utils.LogFlow("Booking", "End")
	}

	// ListBookings Flow
	if *listFlow || *allFlows {
		if len(b) == 0 {
			// NOTE(review): this one setTimeout deadline covers every RPC that
			// GenerateBookings makes, unlike the per-RPC deadlines used elsewhere.
			b = GenerateBookings(setTimeout(ctx), av, &stats, client)
		}
		utils.LogFlow("List Bookings", "Start")
		listed, listErr := api.ListBookings(setTimeout(ctx), b, client)
		stats.ListBookingsSuccess = listErr == nil
		if listErr != nil {
			// Keep the bookings created above so the cancel flow can still
			// release them instead of leaking them on the server.
			// NOTE(review): assumes ListBookings' result is not meaningful on
			// error — confirm against the api package.
			log.Println(listErr.Error())
		} else {
			b = listed
		}
		utils.LogFlow("List Bookings", "End")
	}

	// GetBookingStatus Flow
	if *statusFlow || *allFlows {
		if len(b) == 0 {
			b = GenerateBookings(setTimeout(ctx), av, &stats, client)
		}

		utils.LogFlow("BookingStatus", "Start")
		totalBookings := len(b)

		// NOTE(review): bookings that fail GetBookingStatus are dropped from b,
		// so the cancel flow below never cancels them.
		j := 0
		for i, booking := range b {
			if err = api.GetBookingStatus(setTimeout(ctx), booking, client); err != nil {
				log.Printf("%s. abandoning booking %d/%d", err.Error(), i, totalBookings)
				stats.GetBookingStatusErrors++
				continue
			}
			stats.GetBookingStatusSuccess++
			b[j] = booking
			j++
		}
		b = b[:j]
		utils.LogFlow("BookingStatus", "End")
	}

	// CancelBooking Flow: release every booking still held, regardless of flags.
	if len(b) > 0 {
		utils.LogFlow("Cancel Booking", "Start")
		for i, booking := range b {
			if err = api.CancelBooking(setTimeout(ctx), booking, client); err != nil {
				log.Printf("%s. abandoning booking %d/%d", err.Error(), i, len(b))
				stats.CancelBookingsErrors++
				continue
			}
			stats.CancelBookingsSuccess++
		}
		utils.LogFlow("Cancel Booking", "End")
	}

	// Rescheduling is nuanced and can be isolated
	// from the rest of the tests.
	if *rescheduleFlow || *allFlows {
		utils.LogFlow("Rescheduling", "Start")
		stats.ReschedulingSuccess = true
		if err = api.Rescheduling(setTimeout(ctx), av, client); err != nil {
			log.Println(err.Error())
			stats.ReschedulingSuccess = false
		}
		utils.LogFlow("Rescheduling", "End")
	}

	logStats(stats)
}
