Optionally add cert_file to connect via https (and with that optionally add server name if it differs from the base URL)
diff --git a/README.md b/README.md
index 1db5501..619efd8 100644
--- a/README.md
+++ b/README.md
@@ -116,6 +116,10 @@
(default "example.com:80")
-username_password string
(eg 'username:password') credentials for your server. Leave blank to bypass authentication.
+ -ca_file string
+ Absolute path to your server's Certificate Authority root cert. Downloading all roots currently recommended by the Google Internet Authority is a suitable alternative https://pki.google.com/roots.pem. Leave blank to connect using http rather than https.
+ -full_server_name string
+ Fully qualified domain name. Same name used to sign CN. Only necessary if ca_file is specified and the base URL differs from the server address.
Example Usage:
diff --git a/api/api.go b/api/api.go
index 5cca4d8..f4cf08b 100644
--- a/api/api.go
+++ b/api/api.go
@@ -19,6 +19,8 @@
import (
"bytes"
+ "crypto/tls"
+ "crypto/x509"
"encoding/base64"
"errors"
"fmt"
@@ -56,24 +58,48 @@
baseURL string
}
+func setupCertConfig(caFile string, fullServerName string) (*tls.Config, error) {
+ if caFile == "" {
+ return nil, nil
+ }
+ b, err := ioutil.ReadFile(caFile)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read root certificates file: %v", err)
+ }
+ cp := x509.NewCertPool()
+ if !cp.AppendCertsFromPEM(b) {
+ return nil, errors.New("failed to parse root certificates, please check your roots file (ca_file flag) and try again")
+ }
+ return &tls.Config{
+ RootCAs: cp,
+ ServerName: fullServerName,
+ }, nil
+}
+
// InitHTTPConnection creates and returns a new HTTPConnection object
// with a given server address and username/password.
-func InitHTTPConnection(serverAddr string, credentialsFile string) (*HTTPConnection, error) {
- usernamePassword := ""
- if credentialsFile != "" {
- // Set up username/password.
- data, err := ioutil.ReadFile(credentialsFile)
- if err != nil {
- return nil, err
- }
- usernamePassword = strings.Replace(string(data), "\n", "", -1)
+func InitHTTPConnection(serverAddr string, credentialsFile string, caFile string, fullServerName string) (*HTTPConnection, error) {
+ // Set up username/password.
+ data, err := ioutil.ReadFile(credentialsFile)
+ if err != nil {
+ return nil, err
+ }
+ usernamePassword := strings.Replace(string(data), "\n", "", -1)
+ config, err := setupCertConfig(caFile, fullServerName)
+ if err != nil {
+ return nil, err
+ }
+ protocol := "http"
+ if config != nil {
+ protocol = "https"
}
return &HTTPConnection{
- client: &http.Client{Timeout: time.Duration(1 * time.Second)},
+ client: &http.Client{
+ Transport: &http.Transport{TLSClientConfig: config},
+ },
credentials: "Basic " + base64.StdEncoding.EncodeToString([]byte(usernamePassword)),
marshaler: &jsonpb.Marshaler{OrigName: true},
- // TODO(wsilberm): Use https
- baseURL: "http://" + serverAddr,
+ baseURL: protocol + "://" + serverAddr,
}, nil
}
@@ -105,7 +131,7 @@
defer utils.LogFlow("Health Check", "End")
// See if we get a response.
- _, err := http.Get(conn.getURL(""))
+ _, err := conn.client.Get(conn.getURL(""))
if err != nil {
return fmt.Errorf("could not complete health check: %v", err)
}
diff --git a/testclient/main.go b/testclient/main.go
index b7610e2..c85de5f 100644
--- a/testclient/main.go
+++ b/testclient/main.go
@@ -34,7 +34,7 @@
var (
serverAddr = flag.String("server_addr", "example.com:80", "Your http server's address in the format of host:port")
- credentialsFile = flag.String("credentials_file", "", "File containing credentials for your server. Leave blank to bypass authentication.")
+ credentialsFile = flag.String("credentials_file", "", "File containing credentials for your server. Leave blank to bypass authentication. File should have exactly one line of the form 'username:password'.")
testSlots = flag.Int("num_test_slots", 10, "Maximum number of slots to test from availability_feed. Slots will be selected randomly")
allFlows = flag.Bool("all_tests", false, "Whether to test all endpoints.")
healthFlow = flag.Bool("health_check_test", false, "Whether to test the Health endpoint.")
@@ -46,6 +46,8 @@
cancelAllBookings = flag.Bool("cancel_all_bookings", false, "This option assumes that the ListBookings and UpdateBooking endpoints are fully functional. This is a convenience flag for purging your system of all previously created bookings.")
availabilityFeed = flag.String("availability_feed", "", "Absolute path to availability feed required for all tests except health. Feeds can be in either json or pb3 format")
outputDir = flag.String("output_dir", "", "Absolute path of dir to dump log file.")
+ caFile = flag.String("ca_file", "", "Absolute path to your server's Certificate Authority root cert. Downloading all roots currently recommended by the Google Internet Authority is a suitable alternative https://pki.google.com/roots.pem. Leave blank to connect using http rather than https.")
+ fullServerName = flag.String("full_server_name", "", "Fully qualified domain name. Same name used to sign CN. Only necessary if ca_file is specified and the base URL differs from the server address.")
)
type counters struct {
@@ -171,7 +173,7 @@
defer f.Close()
log.SetOutput(f)
- conn, err := api.InitHTTPConnection(*serverAddr, *credentialsFile)
+ conn, err := api.InitHTTPConnection(*serverAddr, *credentialsFile, *caFile, *fullServerName)
if err != nil {
log.Fatalf("Failed to init http connection %v", err)
}