Optionally add cert_file to connect via https (and with that optionally add server name if it differs from the base URL)
diff --git a/README.md b/README.md index 1db5501..619efd8 100644 --- a/README.md +++ b/README.md
@@ -116,6 +116,10 @@ (default "example.com:80") -username_password string (eg 'username:password') credentials for your server. Leave blank to bypass authentication. + -ca_file string + Absolute path to your server's Certificate Authority root cert. Downloading all roots currently recommended by the Google Internet Authority is a suitable alternative https://pki.google.com/roots.pem. Leave blank to connect using http rather than https. + -full_server_name string + Fully qualified domain name. Same name used to sign CN. Only necessary if ca_file is specified and the base URL differs from the server address. Example Usage:
diff --git a/api/api.go b/api/api.go index 5cca4d8..f4cf08b 100644 --- a/api/api.go +++ b/api/api.go
@@ -19,6 +19,8 @@ import ( "bytes" + "crypto/tls" + "crypto/x509" "encoding/base64" "errors" "fmt" @@ -56,24 +58,48 @@ baseURL string } +func setupCertConfig(caFile string, fullServerName string) (*tls.Config, error) { + if caFile == "" { + return nil, nil + } + b, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("failed to read root certificates file: %v", err) + } + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(b) { + return nil, errors.New("failed to parse root certificates, please check your roots file (ca_file flag) and try again") + } + return &tls.Config{ + RootCAs: cp, + ServerName: fullServerName, + }, nil +} + // InitHTTPConnection creates and returns a new HTTPConnection object // with a given server address and username/password. -func InitHTTPConnection(serverAddr string, credentialsFile string) (*HTTPConnection, error) { - usernamePassword := "" - if credentialsFile != "" { - // Set up username/password. - data, err := ioutil.ReadFile(credentialsFile) - if err != nil { - return nil, err - } - usernamePassword = strings.Replace(string(data), "\n", "", -1) +func InitHTTPConnection(serverAddr string, credentialsFile string, caFile string, fullServerName string) (*HTTPConnection, error) { + // Set up username/password. + data, err := ioutil.ReadFile(credentialsFile) + if err != nil { + return nil, err + } + usernamePassword := strings.Replace(string(data), "\n", "", -1) + config, err := setupCertConfig(caFile, fullServerName) + if err != nil { + return nil, err + } + protocol := "http" + if config != nil { + protocol = "https" } return &HTTPConnection{ - client: &http.Client{Timeout: time.Duration(1 * time.Second)}, + client: &http.Client{ + Transport: &http.Transport{TLSClientConfig: config}, + }, credentials: "Basic " + base64.StdEncoding.EncodeToString([]byte(usernamePassword)), marshaler: &jsonpb.Marshaler{OrigName: true}, - // TODO(wsilberm): Use https - baseURL: "http://" + serverAddr, + baseURL: protocol + "://" + serverAddr, }, nil } @@ -105,7 +131,7 @@ defer utils.LogFlow("Health Check", "End") // See if we get a response. - _, err := http.Get(conn.getURL("")) + _, err := conn.client.Get(conn.getURL("")) if err != nil { return fmt.Errorf("could not complete health check: %v", err) }
diff --git a/testclient/main.go b/testclient/main.go index b7610e2..c85de5f 100644 --- a/testclient/main.go +++ b/testclient/main.go
@@ -34,7 +34,7 @@ var ( serverAddr = flag.String("server_addr", "example.com:80", "Your http server's address in the format of host:port") - credentialsFile = flag.String("credentials_file", "", "File containing credentials for your server. Leave blank to bypass authentication.") + credentialsFile = flag.String("credentials_file", "", "File containing credentials for your server. Leave blank to bypass authentication. File should have exactly one line of the form 'username:password'.") testSlots = flag.Int("num_test_slots", 10, "Maximum number of slots to test from availability_feed. Slots will be selected randomly") allFlows = flag.Bool("all_tests", false, "Whether to test all endpoints.") healthFlow = flag.Bool("health_check_test", false, "Whether to test the Health endpoint.") @@ -46,6 +46,8 @@ cancelAllBookings = flag.Bool("cancel_all_bookings", false, "This option assumes that the ListBookings and UpdateBooking endpoints are fully functional. This is a convenience flag for purging your system of all previously created bookings.") availabilityFeed = flag.String("availability_feed", "", "Absolute path to availability feed required for all tests except health. Feeds can be in either json or pb3 format") outputDir = flag.String("output_dir", "", "Absolute path of dir to dump log file.") + caFile = flag.String("ca_file", "", "Absolute path to your server's Certificate Authority root cert. Downloading all roots currently recommended by the Google Internet Authority is a suitable alternative https://pki.google.com/roots.pem. Leave blank to connect using http rather than https.") + fullServerName = flag.String("full_server_name", "", "Fully qualified domain name. Same name used to sign CN. Only necessary if ca_file is specified and the base URL differs from the server address.") ) type counters struct { @@ -171,7 +173,7 @@ defer f.Close() log.SetOutput(f) - conn, err := api.InitHTTPConnection(*serverAddr, *credentialsFile) + conn, err := api.InitHTTPConnection(*serverAddr, *credentialsFile, *caFile, *fullServerName) if err != nil { log.Fatalf("Failed to init http connection %v", err) }