Go 学习：测试 HTTP 服务器

代码

package filelisting

import (
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"strings"
)

// prefix is the URL path prefix HandlerFIleList requires; whatever follows
// it in the request path is used as the file path to open.
const (
	prefix = "/list/"
)

// userError is an error whose text is meant for the client rather than only
// the server log. It exposes that text through Message so that a caller
// matching on an Error+Message method set (errWrapper in package main) can
// send it back to the client.
type userError string

// Message returns the client-facing text, which is the string itself.
func (ue userError) Message() string {
	return string(ue)
}

// Error implements the error interface with the same text as Message.
func (ue userError) Error() string {
	return string(ue)
}

// HandlerFIleList writes the raw contents of the file named by the part of
// the request path after prefix (e.g. GET /list/fib.txt opens "fib.txt",
// relative to the process's working directory).
//
// It returns a userError if the path does not start with prefix, and returns
// any error from opening or reading the file unchanged (e.g. one satisfying
// os.IsNotExist) so the caller can map it to an HTTP status code.
//
// NOTE(review): the file path comes straight from the client and is not
// cleaned or confined to a directory here. When mounted on http.ServeMux the
// mux redirects paths containing "..", "." or repeated slashes, but any
// other path the process can read is still served. Fine for a demo; do not
// expose as-is.
func HandlerFIleList(writer http.ResponseWriter, request *http.Request) error {
	if !strings.HasPrefix(request.URL.Path, prefix) {
		return userError(fmt.Sprintf("path %s must start with %s", request.URL.Path, prefix))
	}
	path := request.URL.Path[len(prefix):]

	file, err := os.Open(path)
	if err != nil {
		return err
	}
	// Defer Close only after Open succeeded; on failure file is nil.
	defer file.Close()

	content, err := ioutil.ReadAll(file)
	if err != nil {
		return err
	}
	// Once Write starts the response is committed, so a failure here
	// (typically a disconnected client) cannot be reported back; ignore it.
	_, _ = writer.Write(content)
	return nil
}

package main

import (
	"errors"
	"log"
	"net/http"
	"os"

	"learngo/errhandling/filelistingserver/filelisting"
)

// appHandler is an HTTP handler that reports failure by returning an error
// instead of writing an error response itself; errWrapper turns it into a
// plain func(http.ResponseWriter, *http.Request).
type appHandler func(w http.ResponseWriter, r *http.Request) error

// userError is an error that also carries a Message safe to show to the
// client; errWrapper replies to such errors with that message and
// 400 Bad Request.
type userError interface {
	Error() string
	Message() string
}

// errWrapper adapts an appHandler to the plain handler signature expected by
// http.HandleFunc and translates its failures into HTTP error responses:
//
//   - a panic in handler is recovered, logged and answered with 500;
//   - an error that is, or wraps, a userError is answered with 400 and the
//     error's Message;
//   - an error that is, or wraps, os.ErrNotExist gives 404, os.ErrPermission
//     gives 403, and anything else gives 500, each with the generic status
//     text as body.
//
// errors.As / errors.Is are used (rather than a bare type assertion or
// os.IsNotExist) so handlers can add context with fmt.Errorf("...: %w", err)
// and still get the right status code.
func errWrapper(handler appHandler) func(writer http.ResponseWriter, request *http.Request) {
	return func(writer http.ResponseWriter, request *http.Request) {
		defer func() {
			if r := recover(); r != nil {
				log.Printf("Panic %v", r)
				http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			}
		}()

		err := handler(writer, request)
		if err == nil {
			return
		}
		log.Printf("Error occurred handling request: %s", err.Error())

		var userErr userError
		if errors.As(err, &userErr) {
			http.Error(writer, userErr.Message(), http.StatusBadRequest)
			return
		}

		code := http.StatusInternalServerError
		switch {
		case errors.Is(err, os.ErrNotExist):
			code = http.StatusNotFound
		case errors.Is(err, os.ErrPermission):
			code = http.StatusForbidden
		}
		http.Error(writer, http.StatusText(code), code)
	}
}

// prefix is the URL path under which main mounts the file-listing handler.
const (
	prefix = "/list/"
)

func main() {
	http.HandleFunc(prefix, errWrapper(filelisting.HandlerFIleList))
	err := http.ListenAndServe(":8888", nil)
	if err != nil {
		panic(err)
	}
}


使用假的 Request/Response（httptest.NewRecorder / httptest.NewRequest）

package main

import (
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
)

// testingUserError is a minimal client-facing error for the tests: Error and
// Message both yield the underlying string unchanged.
type testingUserError string

// Message returns the text meant for the client.
func (u testingUserError) Message() string {
	return string(u)
}

// Error implements the error interface by delegating to Message.
func (u testingUserError) Error() string {
	msg := u.Message()
	return msg
}

// Fake appHandler fixtures. Each one drives errWrapper down exactly one code
// path; the expected status code and body for each live in the test table.

// noError succeeds and writes "no error\n" to the response.
func noError(writer http.ResponseWriter, request *http.Request) error {
	fmt.Fprint(writer, "no error\n")
	return nil
}

// errPanic panics with a non-error value (the int 123).
func errPanic(_ http.ResponseWriter, _ *http.Request) error {
	panic(123)
}

// errUserError fails with an error that carries a client-facing message.
func errUserError(_ http.ResponseWriter, _ *http.Request) error {
	const msg = "user Error"
	return testingUserError(msg)
}

// errNotFound fails with the os "file does not exist" sentinel.
func errNotFound(_ http.ResponseWriter, _ *http.Request) error {
	err := os.ErrNotExist
	return err
}

// errNoPermission fails with the os "permission denied" sentinel.
func errNoPermission(_ http.ResponseWriter, _ *http.Request) error {
	err := os.ErrPermission
	return err
}

// errUnknown fails with an error errWrapper has no special mapping for.
func errUnknown(_ http.ResponseWriter, _ *http.Request) error {
	err := errors.New("unknown error")
	return err
}

// TestErrWrapper calls errWrapper in-process, using httptest's fake
// ResponseRecorder and request instead of a real server, and checks the
// status code and newline-trimmed body produced for each fake handler.
func TestErrWrapper(t *testing.T) {
	tests := []struct {
		handler appHandler
		code    int
		msg     string
	}{
		{noError, 200, "no error"},
		{errUserError, 400, "user Error"},
		{errPanic, 500, "Internal Server Error"},
		{errNotFound, 404, "Not Found"},
		{errNoPermission, 403, "Forbidden"},
		{errUnknown, 500, "Internal Server Error"},
	}

	for _, test := range tests {
		wrapper := errWrapper(test.handler)
		response := httptest.NewRecorder()
		request := httptest.NewRequest(http.MethodGet, "http://www.baidu.com", nil)
		wrapper(response, request)

		bytes, err := ioutil.ReadAll(response.Body)
		if err != nil {
			t.Fatalf("reading recorded body: %v", err)
		}
		body := strings.Trim(string(bytes), "\n")
		if test.msg != body || response.Code != test.code {
			// The format reads "expected: (code,msg); actual: (code,msg)",
			// so the expected values must be passed first.
			t.Errorf("期望值:(%d,%s); 实际值:(%d,%s)", test.code, test.msg, response.Code, body)
		}
	}
}

（此处原为一张截图，图片已缺失）

启动真实的测试服务器（httptest.NewServer）

package main

import (
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
)

// testingUserError is a minimal client-facing error for the tests: Error and
// Message both yield the underlying string unchanged.
type testingUserError string

// Message returns the text meant for the client.
func (u testingUserError) Message() string {
	return string(u)
}

// Error implements the error interface by delegating to Message.
func (u testingUserError) Error() string {
	msg := u.Message()
	return msg
}

// Fake appHandler fixtures. Each one drives errWrapper down exactly one code
// path; the expected status code and body for each live in the test table.

// noError succeeds and writes "no error\n" to the response.
func noError(writer http.ResponseWriter, request *http.Request) error {
	fmt.Fprint(writer, "no error\n")
	return nil
}

// errPanic panics with a non-error value (the int 123).
func errPanic(_ http.ResponseWriter, _ *http.Request) error {
	panic(123)
}

// errUserError fails with an error that carries a client-facing message.
func errUserError(_ http.ResponseWriter, _ *http.Request) error {
	const msg = "user Error"
	return testingUserError(msg)
}

// errNotFound fails with the os "file does not exist" sentinel.
func errNotFound(_ http.ResponseWriter, _ *http.Request) error {
	err := os.ErrNotExist
	return err
}

// errNoPermission fails with the os "permission denied" sentinel.
func errNoPermission(_ http.ResponseWriter, _ *http.Request) error {
	err := os.ErrPermission
	return err
}

// errUnknown fails with an error errWrapper has no special mapping for.
func errUnknown(_ http.ResponseWriter, _ *http.Request) error {
	err := errors.New("unknown error")
	return err
}

// TestErrWrapprtInServer runs each fake handler behind a real httptest
// server and checks the status code and newline-trimmed body received over
// HTTP. Each iteration runs in its own closure so the deferred server and
// response-body closes fire per case instead of leaking until the test ends.
func TestErrWrapprtInServer(t *testing.T) {
	tests := []struct {
		handler appHandler
		code    int
		msg     string
	}{
		{noError, 200, "no error"},
		{errUserError, 400, "user Error"},
		{errPanic, 500, "Internal Server Error"},
		{errNotFound, 404, "Not Found"},
		{errNoPermission, 403, "Forbidden"},
		{errUnknown, 500, "Internal Server Error"},
	}

	for _, test := range tests {
		func() {
			server := httptest.NewServer(http.HandlerFunc(errWrapper(test.handler)))
			defer server.Close()

			response, err := server.Client().Get(server.URL)
			if err != nil {
				t.Fatalf("GET %s: %v", server.URL, err)
			}
			defer response.Body.Close()

			bytes, err := ioutil.ReadAll(response.Body)
			if err != nil {
				t.Fatalf("reading response body: %v", err)
			}
			body := strings.Trim(string(bytes), "\n")
			if response.StatusCode != test.code || body != test.msg {
				t.Errorf("期望值(%d,%s); 实际值:(%d,%s)", test.code, test.msg, response.StatusCode, body)
			}
		}()
	}
}

（此处原为一张截图，图片已缺失）

将重复代码抽取为公共函数

package main

import (
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
)

// testingUserError is a minimal client-facing error for the tests: Error and
// Message both yield the underlying string unchanged.
type testingUserError string

// Message returns the text meant for the client.
func (u testingUserError) Message() string {
	return string(u)
}

// Error implements the error interface by delegating to Message.
func (u testingUserError) Error() string {
	msg := u.Message()
	return msg
}

// Fake appHandler fixtures. Each one drives errWrapper down exactly one code
// path; the expected status code and body for each live in the tests table.

// noError succeeds and writes "no error\n" to the response.
func noError(writer http.ResponseWriter, request *http.Request) error {
	fmt.Fprint(writer, "no error\n")
	return nil
}

// errPanic panics with a non-error value (the int 123).
func errPanic(_ http.ResponseWriter, _ *http.Request) error {
	panic(123)
}

// errUserError fails with an error that carries a client-facing message.
func errUserError(_ http.ResponseWriter, _ *http.Request) error {
	const msg = "user Error"
	return testingUserError(msg)
}

// errNotFound fails with the os "file does not exist" sentinel.
func errNotFound(_ http.ResponseWriter, _ *http.Request) error {
	err := os.ErrNotExist
	return err
}

// errNoPermission fails with the os "permission denied" sentinel.
func errNoPermission(_ http.ResponseWriter, _ *http.Request) error {
	err := os.ErrPermission
	return err
}

// errUnknown fails with an error errWrapper has no special mapping for.
func errUnknown(_ http.ResponseWriter, _ *http.Request) error {
	err := errors.New("unknown error")
	return err
}

// tests is the table shared by TestErrWrapper and TestErrWrapprtInServer:
// each fake handler together with the status code and newline-trimmed body
// errWrapper is expected to produce for it.
var tests = []struct {
	handler appHandler
	code    int
	msg     string
}{
	{handler: noError, code: http.StatusOK, msg: "no error"},
	{handler: errUserError, code: http.StatusBadRequest, msg: "user Error"},
	{handler: errPanic, code: http.StatusInternalServerError, msg: "Internal Server Error"},
	{handler: errNotFound, code: http.StatusNotFound, msg: "Not Found"},
	{handler: errNoPermission, code: http.StatusForbidden, msg: "Forbidden"},
	{handler: errUnknown, code: http.StatusInternalServerError, msg: "Internal Server Error"},
}

// TestErrWrapper exercises errWrapper in-process: each handler from tests is
// run against a ResponseRecorder and a synthetic request, so nothing touches
// the network, and the recorded result is checked with checkResponse.
func TestErrWrapper(t *testing.T) {
	for _, tc := range tests {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "http://www.baidu.com", nil)
		errWrapper(tc.handler)(rec, req)
		checkResponse(t, rec.Result(), tc.code, tc.msg)
	}
}

// TestErrWrapprtInServer runs each handler from tests behind a real httptest
// server and checks the response received over HTTP with checkResponse.
// Each iteration runs in its own closure so the deferred server and
// response-body closes fire per case instead of leaking listeners and
// connections until the test ends; a failed request aborts the test instead
// of dereferencing a nil response.
func TestErrWrapprtInServer(t *testing.T) {
	for _, test := range tests {
		func() {
			server := httptest.NewServer(http.HandlerFunc(errWrapper(test.handler)))
			defer server.Close()

			response, err := server.Client().Get(server.URL)
			if err != nil {
				t.Fatalf("GET %s: %v", server.URL, err)
			}
			defer response.Body.Close()

			checkResponse(t, response, test.code, test.msg)
		}()
	}
}

// checkResponse reads and closes response's body, then reports a test error
// unless the status code equals expectedCode and the body, with leading and
// trailing newlines trimmed, equals expectedMsg. It uses Errorf rather than
// Fatalf so a shared table keeps checking the remaining cases.
func checkResponse(t *testing.T, response *http.Response, expectedCode int, expectedMsg string) {
	t.Helper() // attribute failures to the calling test line
	defer response.Body.Close()

	bytes, err := ioutil.ReadAll(response.Body)
	if err != nil {
		t.Errorf("reading response body: %v", err)
		return
	}
	body := strings.Trim(string(bytes), "\n")
	if response.StatusCode != expectedCode || body != expectedMsg {
		t.Errorf("期望值(%d,%s); 实际值:(%d,%s)", expectedCode, expectedMsg, response.StatusCode, body)
	}
}

使用假的 Request/Response vs. 启动服务器

  • 使用假的 Request/Response

    • 速度快
    • 更像单元测试，只测试一个小的函数
  • 启动服务器

    • 集成度较高，测试覆盖的代码范围更大
    • 速度慢
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

.番茄炒蛋

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值