代码：filelisting 包（文件读取处理函数）与 main 包（errWrapper 与服务入口）
package filelisting
import (
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
)
// prefix is the URL path prefix this handler serves; the remainder of the
// request path after it is passed to os.Open as the file name.
const prefix = "/list/"

// userError is a string-backed error whose text is intended to be shown to
// the client as-is. Message returns the same text as Error.
type userError string

// Error implements the error interface.
func (e userError) Error() string {
	return e.Message()
}

// Message returns the user-facing text of the error.
func (e userError) Message() string {
	return string(e)
}

// HandlerFIleList writes the raw contents of the file named by the part of
// request.URL.Path that follows prefix.
//
// It returns a userError when the path does not start with prefix, and the
// error from os.Open / ioutil.ReadAll when the file cannot be opened or read
// (so callers can map e.g. os.IsNotExist / os.IsPermission to status codes).
//
// NOTE(review): the file path comes straight from the URL and is not
// validated here. When mounted on an http.ServeMux, paths containing "." or
// ".." elements are cleaned (redirected) before reaching the handler — confirm
// every caller mounts it that way before exposing it.
func HandlerFIleList(writer http.ResponseWriter, request *http.Request) error {
	if !strings.HasPrefix(request.URL.Path, prefix) {
		return userError(fmt.Sprintf("path %s must start with %s", request.URL.Path, prefix))
	}
	path := request.URL.Path[len(prefix):]

	file, err := os.Open(path)
	if err != nil {
		return err
	}
	// Defer only after a successful Open: on failure file is nil.
	defer file.Close()

	content, err := ioutil.ReadAll(file)
	if err != nil {
		return err
	}
	// Write commits the response; a failure here usually means the client is
	// gone and there is nothing useful left to send, so it is ignored.
	_, _ = writer.Write(content)
	return nil
}
package main
import (
"learngo/errhandling/filelistingserver/filelisting"
"log"
"net/http"
"os"
)
// appHandler is an HTTP handler that reports failure by returning an error
// instead of writing an error response itself; errWrapper converts that error
// into an HTTP response.
type appHandler func(writer http.ResponseWriter, request *http.Request) error

// userError is implemented by errors whose Message may be shown to the
// client. errWrapper answers such errors with 400 Bad Request and Message()
// as the body.
type userError interface {
	error
	Message() string
}

// errWrapper adapts an appHandler to the signature accepted by
// http.HandleFunc. The handler's outcome is mapped to a response as follows:
//   - nil error: whatever the handler wrote is left untouched;
//   - userError: 400 with userErr.Message() as the body;
//   - os.IsNotExist: 404; os.IsPermission: 403; anything else: 500,
//     each with http.StatusText(code) as the body;
//   - panic: logged and answered with 500, except http.ErrAbortHandler,
//     which is re-panicked so net/http can abort the response as intended.
//
// Every non-nil error is logged before it is mapped.
func errWrapper(handler appHandler) func(writer http.ResponseWriter, request *http.Request) {
	return func(writer http.ResponseWriter, request *http.Request) {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			if r == http.ErrAbortHandler {
				// Deliberate abort signal understood by net/http; turning it
				// into a 500 would defeat its purpose.
				panic(r)
			}
			log.Printf("Panic %v", r)
			http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		}()

		err := handler(writer, request)
		if err == nil {
			return
		}
		log.Printf("Error occurred handling request: %s", err.Error())

		// User-facing errors carry their own message for the client.
		if userErr, ok := err.(userError); ok {
			http.Error(writer, userErr.Message(), http.StatusBadRequest)
			return
		}

		var code int
		switch {
		case os.IsNotExist(err):
			code = http.StatusNotFound
		case os.IsPermission(err):
			code = http.StatusForbidden
		default:
			code = http.StatusInternalServerError
		}
		http.Error(writer, http.StatusText(code), code)
	}
}
// prefix is the URL path under which the file-listing handler is mounted.
const prefix = "/list/"

// main mounts filelisting.HandlerFIleList (wrapped by errWrapper) at prefix
// on the default ServeMux and serves HTTP on port 8888. If ListenAndServe
// returns an error, main panics with it.
func main() {
	http.HandleFunc(prefix, errWrapper(filelisting.HandlerFIleList))

	if err := http.ListenAndServe(":8888", nil); err != nil {
		panic(err)
	}
}
测试方法一：使用假的 Request/Response（httptest.NewRecorder + httptest.NewRequest）
package main
import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
// testingUserError is a string-backed stand-in for a user-facing error:
// both Message and Error yield the string itself.
type testingUserError string

// Message returns the error text unchanged.
func (te testingUserError) Message() string {
	return string(te)
}

// Error satisfies the error interface by delegating to Message.
func (te testingUserError) Error() string {
	return te.Message()
}

// errPanic simulates a handler that crashes.
func errPanic(_ http.ResponseWriter, _ *http.Request) error {
	panic(123)
}

// errUserError simulates a handler rejecting bad client input.
func errUserError(_ http.ResponseWriter, _ *http.Request) error {
	return testingUserError("user Error")
}

// errNotFound simulates a handler whose target file is missing.
func errNotFound(_ http.ResponseWriter, _ *http.Request) error {
	return os.ErrNotExist
}

// errNoPermission simulates a handler denied access to its target.
func errNoPermission(_ http.ResponseWriter, _ *http.Request) error {
	return os.ErrPermission
}

// errUnknown simulates a handler failing with an unclassified error.
func errUnknown(_ http.ResponseWriter, _ *http.Request) error {
	return errors.New("unknown error")
}

// noError simulates a successful handler that writes "no error\n".
func noError(w http.ResponseWriter, _ *http.Request) error {
	fmt.Fprintln(w, "no error")
	return nil
}
// TestErrWrapper drives errWrapper in-process with an httptest.ResponseRecorder
// (no network) and checks that each fake handler's outcome is mapped to the
// expected status code and newline-trimmed response body.
func TestErrWrapper(t *testing.T) {
	tests := []struct {
		handler appHandler
		code    int
		msg     string
	}{
		{noError, 200, "no error"},
		{errUserError, 400, "user Error"},
		{errPanic, 500, "Internal Server Error"},
		{errNotFound, 404, "Not Found"},
		{errNoPermission, 403, "Forbidden"},
		{errUnknown, 500, "Internal Server Error"},
	}
	for _, test := range tests {
		wrapper := errWrapper(test.handler)
		response := httptest.NewRecorder()
		request := httptest.NewRequest(http.MethodGet, "http://www.baidu.com", nil)
		wrapper(response, request)

		data, err := ioutil.ReadAll(response.Body)
		if err != nil {
			t.Fatalf("reading recorded body: %v", err)
		}
		body := strings.Trim(string(data), "\n")
		if test.msg != body || response.Code != test.code {
			// Expected values first, actual values second, matching the labels.
			t.Errorf("期望值:(%d,%s); 实际值:(%d,%s)", test.code, test.msg, response.Code, body)
		}
	}
}
测试方法二：启动真实的测试服务器（httptest.NewServer）
package main
import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
// testingUserError is the error returned by errUserError; its text doubles
// as the user-facing message.
type testingUserError string

// Error returns the error text via Message.
func (u testingUserError) Error() string { return u.Message() }

// Message returns the user-facing text.
func (u testingUserError) Message() string { return string(u) }

// The handlers below each simulate one outcome that errWrapper must map to
// an HTTP response.

// errPanic panics with an arbitrary non-error value.
func errPanic(http.ResponseWriter, *http.Request) error { panic(123) }

// errUserError returns a user-facing error.
func errUserError(http.ResponseWriter, *http.Request) error {
	return testingUserError("user Error")
}

// errNotFound returns the os not-exist sentinel.
func errNotFound(http.ResponseWriter, *http.Request) error { return os.ErrNotExist }

// errNoPermission returns the os permission sentinel.
func errNoPermission(http.ResponseWriter, *http.Request) error { return os.ErrPermission }

// errUnknown returns a fresh unclassified error on each call.
func errUnknown(http.ResponseWriter, *http.Request) error {
	return errors.New("unknown error")
}

// noError writes "no error\n" and succeeds.
func noError(rw http.ResponseWriter, _ *http.Request) error {
	fmt.Fprintln(rw, "no error")
	return nil
}
// TestErrWrapprtInServer serves each fake handler (wrapped by errWrapper)
// from a real httptest.Server, issues a GET against it, and checks the status
// code and newline-trimmed body. Each case runs in its own closure so the
// server and response body are closed before the next case starts.
func TestErrWrapprtInServer(t *testing.T) {
	tests := []struct {
		handler appHandler
		code    int
		msg     string
	}{
		{noError, 200, "no error"},
		{errUserError, 400, "user Error"},
		{errPanic, 500, "Internal Server Error"},
		{errNotFound, 404, "Not Found"},
		{errNoPermission, 403, "Forbidden"},
		{errUnknown, 500, "Internal Server Error"},
	}
	for _, test := range tests {
		func() {
			server := httptest.NewServer(http.HandlerFunc(errWrapper(test.handler)))
			defer server.Close()

			response, err := http.Get(server.URL)
			if err != nil {
				t.Errorf("GET %s: %v", server.URL, err)
				return
			}
			// Deferred after server.Close, so it runs first (LIFO).
			defer response.Body.Close()

			data, err := ioutil.ReadAll(response.Body)
			if err != nil {
				t.Errorf("reading response body: %v", err)
				return
			}
			body := strings.Trim(string(data), "\n")
			if response.StatusCode != test.code || body != test.msg {
				t.Errorf("期望值(%d,%s); 实际值:(%d,%s)", test.code, test.msg, response.StatusCode, body)
			}
		}()
	}
}
将重复代码抽取为公共函数（checkResponse）
package main
import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
// testingUserError is a string error type used to exercise errWrapper's
// user-error branch.
type testingUserError string

// Error returns the underlying string (identical to Message).
func (msg testingUserError) Error() string {
	return msg.Message()
}

// Message returns the underlying string.
func (msg testingUserError) Message() string {
	text := string(msg)
	return text
}

// errPanic is a handler that always panics.
func errPanic(_ http.ResponseWriter, _ *http.Request) error {
	panic(123)
}

// errUserError is a handler that always fails with a user-facing error.
func errUserError(_ http.ResponseWriter, _ *http.Request) error {
	e := testingUserError("user Error")
	return e
}

// errNotFound is a handler that always reports os.ErrNotExist.
func errNotFound(_ http.ResponseWriter, _ *http.Request) error {
	return os.ErrNotExist
}

// errNoPermission is a handler that always reports os.ErrPermission.
func errNoPermission(_ http.ResponseWriter, _ *http.Request) error {
	return os.ErrPermission
}

// errUnknown is a handler that always fails with a generic error.
func errUnknown(_ http.ResponseWriter, _ *http.Request) error {
	return errors.New("unknown error")
}

// noError is a handler that writes "no error\n" and returns nil.
func noError(out http.ResponseWriter, _ *http.Request) error {
	fmt.Fprintln(out, "no error")
	return nil
}
// tests is the table shared by TestErrWrapper and TestErrWrapprtInServer:
// each entry pairs a fake handler with the status code and newline-trimmed
// body that errWrapper is expected to produce for it.
var tests = []struct {
	handler appHandler
	code    int
	msg     string
}{
	{handler: noError, code: http.StatusOK, msg: "no error"},
	{handler: errUserError, code: http.StatusBadRequest, msg: "user Error"},
	{handler: errPanic, code: http.StatusInternalServerError, msg: "Internal Server Error"},
	{handler: errNotFound, code: http.StatusNotFound, msg: "Not Found"},
	{handler: errNoPermission, code: http.StatusForbidden, msg: "Forbidden"},
	{handler: errUnknown, code: http.StatusInternalServerError, msg: "Internal Server Error"},
}
// TestErrWrapper exercises errWrapper in-process: every entry of tests is
// served into an httptest.ResponseRecorder and the recorded *http.Response
// is verified by checkResponse.
func TestErrWrapper(t *testing.T) {
	for _, tc := range tests {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "http://www.baidu.com", nil)
		errWrapper(tc.handler)(rec, req)
		checkResponse(t, rec.Result(), tc.code, tc.msg)
	}
}
// TestErrWrapprtInServer serves every entry of tests from a real
// httptest.Server and verifies the GET response with checkResponse. Each case
// runs in its own closure so that the response body and the server are
// closed before the next case starts.
func TestErrWrapprtInServer(t *testing.T) {
	for _, test := range tests {
		func() {
			server := httptest.NewServer(http.HandlerFunc(errWrapper(test.handler)))
			defer server.Close()

			response, err := http.Get(server.URL)
			if err != nil {
				t.Errorf("GET %s: %v", server.URL, err)
				return
			}
			// Deferred after server.Close, so it runs first (LIFO).
			defer response.Body.Close()

			checkResponse(t, response, test.code, test.msg)
		}()
	}
}
// checkResponse reads the whole body of response and strips leading and
// trailing newlines. It reports a test error if the status code or the body
// differ from expectedCode / expectedMsg, or if the body cannot be read. It
// does not close response.Body; the caller owns it.
func checkResponse(t *testing.T, response *http.Response, expectedCode int, expectedMsg string) {
	t.Helper() // attribute failures to the calling test line, not this helper

	data, err := ioutil.ReadAll(response.Body)
	if err != nil {
		t.Errorf("reading response body: %v", err)
		return
	}
	body := strings.Trim(string(data), "\n")
	if response.StatusCode != expectedCode || body != expectedMsg {
		t.Errorf("期望值(%d,%s); 实际值:(%d,%s)", expectedCode, expectedMsg, response.StatusCode, body)
	}
}
使用假的 Request/Response VS 启动服务器
- 使用假的 Request/Response（httptest.NewRecorder）
  - 速度快
  - 更像一个单元测试，只测试一个小的函数（errWrapper 返回的函数）
- 启动服务器（httptest.NewServer）
  - 集成度比较高，请求会经过真实的 HTTP 服务端和客户端，测试覆盖的代码量比较大
  - 速度慢