将之前用文件存储数据的版本改成了用数据库存储。锁是凭感觉加的,不确定是否正确=_=||
这个程序主要用来预习和练习数据库。运行前需要先手动建立一个名为 chat 的空数据库。
多重登录(同一账号同时在多处登录)的问题暂时还没有解决……
还有很多功能可以添加,比如查看聊天记录、删除聊天记录等……
server.go
package main
import (
"bufio"
"database/sql"
"fmt"
"log"
"net"
"sync"
"time"
_ "github.com/go-sql-driver/mysql"
)
var (
	// mu is the process-wide lock shared by the request handlers: register,
	// toRead, sendTo and sendAll take it exclusively, login and addFriend take
	// it for reading. sendTo and sendAll also read the clients map under it.
	mu sync.RWMutex
)
// helpText is the command menu shown to a client on connect, after logoff
// and after an unrecognized command.
const helpText = `
please choose options:
- register : 注册。 格式:"register"
- login : 登录。 格式:"login"
- logoff : 注销登录。 格式:"logoff"
- exit : 退出系统。 格式:"exit"
- add : 添加好友。 格式:"add"
- delete : 删除好友。 格式:"delete"
- list : 查看好友列表。 格式:"list"
- sendTo : 给某位好友发送消息。 格式:"sendTo"
- sendAll : 给全部好友发送消息。 格式:"sendAll"
`

// help returns the command menu text.
func help() string { return helpText }
// Client is the session state of a logged-in user. login fills in userName,
// and the per-user tables (friendsOf<userName>, logsOf<userName>,
// messagesOf<userName>) are addressed through it.
type Client struct{ userName string }
// userCH is the send-only side of a connection's output channel, as stored
// in clients.
type userCH chan<- string

// NOTE(review): messages, entering and leaving are never referenced in this
// file; they look like leftovers and could probably be removed.
var messages = make(chan string)
var entering = make(chan userCH)
var leaving = make(chan userCH)

// clients maps a logged-in user name to that user's output channel.
// sendTo and sendAll read it while holding mu.
var clients = make(map[string]userCH)

// db is the shared MySQL connection pool, opened in main.
var db *sql.DB
// main opens the MySQL database "chat" (which must already exist), makes
// sure the user table exists, and then serves chat clients on
// localhost:8000, one goroutine per connection.
//
// NOTE(review): the DSN hard-codes root credentials; consider reading it
// from configuration or the environment.
func main() {
	var err error
	db, err = sql.Open("mysql", "root:123456@/chat?charset=utf8")
	if err != nil {
		log.Fatalln(err)
	}
	// Deferred only after the error check: on failure db may be nil.
	defer db.Close()
	// sql.Open only validates its arguments; Ping actually connects, so a
	// missing database or bad credentials are reported at startup.
	if err = db.Ping(); err != nil {
		log.Fatalln(err)
	}
	// The "chat" database itself is expected to be created beforehand.
	// DDL returns no rows, so use Exec: Query would leak a *sql.Rows.
	if _, err = db.Exec("create table if not exists user(id int auto_increment primary key, userName varchar(20) unique, password varchar(20) not null, tel varchar(20) unique);"); err != nil {
		log.Fatalln(err)
	}
	listener, err := net.Listen("tcp", "localhost:8000")
	if err != nil {
		log.Fatal(err)
	}
	for {
		conn, err := listener.Accept()
		if err != nil {
			log.Print(err)
			continue
		}
		go handle(conn)
	}
}
func handle(conn net.Conn) {
var client *Client
ch := make(chan string)
go clientWriter(conn, ch)
who := conn.RemoteAddr().String()
ch <- "welcome! " + who + "\n"
ch <- help() + "\n"
ch <- ">>"
input := bufio.NewScanner(conn)
for input.Scan() {
var op = input.Text()
fmt.Println(op)
switch op {
case "register":
//register
ok := register(input)
if !ok {
ch <- "Fail register=_=||\nmaybe your userName or phoneNumber is invalid\n"
} else {
ch <- "Success register!\n"
}
case "login":
//login
//TODO 多重登陆问题,做个强制下线?
client = login(input)
if client != nil {
clients[client.userName] = ch
ch <- "Success login!\n" + "Your new messages:\n"
client.toRead(ch)
} else {
ch <- "Fail login=_=||maybe your userName or password is wrong,please check them carefully:)\n"
}
case "logoff":
if client == nil {
ch <- "Please login first:)\n"
} else {
ch <- "you logoff successfully!\n"
ch <- help() + "\n"
clients[client.userName] = nil
client = nil
}
case "add":
var userName string
if input.Scan() {
userName = input.Text()
}
if client == nil {
ch <- "Please login first:)\n"
} else {
ok := client.addFriend(userName)
if ok {
ch <- "add successfully!\n"
} else {
ch <- "maybe no such user=_=||\n"
}
}
case "delete":
var userName string
if input.Scan() {
userName = input.Text()
}
if client == nil {
ch <- "Please login first:)\n"
} else {
ok := client.deleteFriend(userName)
if ok {
ch <- "delete successfully!\n"
} else {
ch <- "maybe no such user=_=||\n"
}
}
case "list":
if client == nil {
ch <- "Please login first:)\n"
} else {
ch <- client.list() + "\n"
}
case "sendTo":
if client == nil {
ch <- "Please login first:)\n"
} else {
var userName, content string
if input.Scan() {
userName = input.Text()
}
if input.Scan() {
content = input.Text()
}
ok := client.sendTo(userName, content)
if ok {
ch <- "发送成功\n"
} else {
ch <- "发送失败,可能你没有这个好友:)\n"
}
}
case "sendAll":
if client == nil {
ch <- "Please login first:)\n"
} else {
var content string
if input.Scan() {
content = input.Text()
}
client.sendAll(content)
}
default:
ch <- "无法识别的命令=_=||请重新输入正确的命令:)\n"
ch <- help() + "\n"
}
ch <- ">>"
}
conn.Close()
}
// clientWriter writes every string received on ch to conn verbatim (no
// newline is appended) and returns once ch is closed. Write errors are
// ignored; the loop keeps draining ch regardless.
func clientWriter(conn net.Conn, ch chan string) {
	for {
		msg, open := <-ch
		if !open {
			return
		}
		fmt.Fprint(conn, msg)
	}
}
// register reads three argument lines from cin (user name, password and
// phone number) and inserts a new row into the user table. On success it
// creates the user's per-user tables via createInfo and returns true.
//
// It returns false when the user name is not a valid identifier, or when
// the insert fails. Both userName and tel are unique columns, so a name or
// phone number that is already registered makes the insert fail.
//
// NOTE(review): passwords are stored in plain text.
func register(cin *bufio.Scanner) bool {
	var userName, password, tel string
	if cin.Scan() {
		userName = cin.Text()
	}
	if cin.Scan() {
		password = cin.Text()
	}
	if cin.Scan() {
		tel = cin.Text()
	}
	// The name is later concatenated into table names (friendsOf<name>, ...),
	// where SQL placeholders cannot be used; reject anything unsafe up front.
	if !validUserName(userName) {
		fmt.Printf("in func register: invalid userName %q\n", userName)
		return false
	}
	mu.Lock()
	defer mu.Unlock()
	res, err := db.Exec("insert user set userName=?,password=?,tel=?", userName, password, tel)
	if err != nil {
		fmt.Println("in func register:insert table user went wrong!!", err)
		return false // most likely a duplicate userName or tel
	}
	if n, err := res.RowsAffected(); err != nil || n != 1 {
		fmt.Println("in func register:insert table user went wrong!!")
		return false
	}
	fmt.Printf("%s successfully register!\n", userName)
	createInfo(userName)
	return true
}

// validUserName reports whether name consists of 1 to 20 ASCII letters,
// digits or underscores. Such a name fits the varchar(20) userName column
// and is safe to splice into a table name.
func validUserName(name string) bool {
	if len(name) == 0 || len(name) > 20 {
		return false
	}
	for i := 0; i < len(name); i++ {
		c := name[i]
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9', c == '_':
		default:
			return false
		}
	}
	return true
}
// createInfo creates the three per-user tables for userName:
//
//	friendsOf<userName>  - the user's friend list
//	logsOf<userName>     - messages the user has sent or received
//	messagesOf<userName> - messages stored while the user was offline
//
// Table names cannot be bound as query parameters, so userName is spliced
// into the SQL directly; callers must pass an already validated name.
// "if not exists" makes repeated calls harmless. Failures are logged
// instead of being silently ignored.
func createInfo(userName string) {
	ddl := [...]string{
		"create table if not exists friendsOf" + userName + "(id int primary key auto_increment, userName varchar(20) unique)",
		"create table if not exists logsOf" + userName + "(id int primary key auto_increment, wtime datetime, fromwho varchar(20), towho varchar(20), msg text)",
		"create table if not exists messagesOf" + userName + "(id int primary key auto_increment, wtime datetime, fromwho varchar(20), msg text)",
	}
	for _, q := range ddl {
		// Exec directly: the old Prepare result was unchecked (a nil *Stmt
		// panics on Exec) and never closed.
		if _, err := db.Exec(q); err != nil {
			log.Println("in func createInfo:", err)
		}
	}
}
// login reads two argument lines from cin (user name and password). If
// they match a row of the user table, it returns a Client for that user.
// Otherwise it returns nil, and database errors are logged.
//
// The returned Client carries the name as stored in the table. Whether the
// SQL comparison is case-insensitive depends on the column collation, while
// the per-user table names built from the name may be case-sensitive, so
// the canonical spelling is used.
//
// NOTE(review): with a case-insensitive collation the password comparison
// is case-insensitive as well.
func login(cin *bufio.Scanner) *Client {
	var userName, password string
	if cin.Scan() {
		userName = cin.Text()
	}
	if cin.Scan() {
		password = cin.Text()
	}
	mu.RLock()
	defer mu.RUnlock()
	// QueryRow releases its connection itself; the previous Query leaked
	// rows on the success path.
	var stored string
	err := db.QueryRow("select userName from user where userName=? and password=?", userName, password).Scan(&stored)
	if err != nil {
		if err != sql.ErrNoRows {
			log.Println("in func login:", err)
		}
		return nil
	}
	return &Client{userName: stored}
}
// toRead delivers c's stored offline messages to ch, one
// "<time> <from> <msg>\n" line each. It copies each message into
// logsOf<c> and then empties messagesOf<c>.
//
// It holds mu for the whole operation, because sendTo and sendAll insert
// into messagesOf tables under the same lock. On a database error it logs
// and returns before the delete, so unread messages are not lost. They may
// be shown, and logged, again on the next login.
func (c *Client) toRead(ch chan string) {
	mu.Lock()
	defer mu.Unlock()
	messages := "messagesOf" + c.userName
	logs := "logsOf" + c.userName

	rows, err := db.Query("select wtime, fromwho, msg from " + messages)
	if err != nil {
		log.Println("in func toRead: ", err)
		return
	}
	defer rows.Close() // only after the error check: rows is nil on error

	stmt, err := db.Prepare("insert " + logs + " set wtime = ?, fromwho = ?, towho = ?, msg = ?")
	if err != nil {
		log.Println("in func toRead: ", err)
		return
	}
	defer stmt.Close()

	for rows.Next() {
		var t, who, msg string
		if err := rows.Scan(&t, &who, &msg); err != nil {
			log.Println("in func toRead:", err)
			return
		}
		ch <- t + " " + who + " " + msg + "\n"
		if _, err := stmt.Exec(t, who, c.userName, msg); err != nil {
			log.Println("fail to insert into logsOf"+c.userName, err)
			return
		}
	}
	if err := rows.Err(); err != nil {
		log.Println("in func toRead:", err)
		return
	}
	if _, err := db.Exec("delete from " + messages); err != nil {
		log.Println("in func toRead:", err)
	}
}
// addFriend adds the registered user userName to c's friend list and
// reports whether it did. It returns false if no such user exists, if the
// insert fails (for example because userName is already a friend, since
// the column is unique), or on any other database error. Errors are logged
// instead of terminating the server.
func (c *Client) addFriend(userName string) bool {
	var stored string
	mu.RLock()
	err := db.QueryRow("select userName from user where userName = ?", userName).Scan(&stored)
	mu.RUnlock()
	if err != nil {
		if err != sql.ErrNoRows {
			log.Println("in func addFriend:", err)
		}
		return false
	}
	// Only c itself writes its friends table, so no lock is taken here.
	// Store the spelling from the user table: it is later used to build the
	// friend's own table names (logsOf<name>, messagesOf<name>).
	if _, err := db.Exec("insert friendsOf"+c.userName+" set userName = ?", stored); err != nil {
		log.Println("in func addFriend:", err)
		return false
	}
	return true
}
// deleteFriend removes userName from c's friend list and reports whether
// such a friend existed. Database errors are logged and reported as false.
func (c *Client) deleteFriend(userName string) bool {
	// A single DELETE both checks for and removes the row; the affected-row
	// count tells whether the friend existed.
	res, err := db.Exec("delete from friendsOf"+c.userName+" where userName = ?", userName)
	if err != nil {
		log.Println("in func deleteFriend:", err)
		return false
	}
	n, err := res.RowsAffected()
	if err != nil {
		log.Println("in func deleteFriend:", err)
		return false
	}
	return n > 0
}
// list returns "Your friends list:" followed by one "\n\t<name>" line per
// friend of c. On a database error it logs the error and returns the header
// plus whatever names were read so far, instead of panicking on nil rows.
func (c *Client) list() string {
	res := "Your friends list:"
	rows, err := db.Query("select userName from friendsOf" + c.userName)
	if err != nil {
		log.Println("in func list:", err)
		return res
	}
	defer rows.Close()
	for rows.Next() {
		var userName string
		if err := rows.Scan(&userName); err != nil {
			log.Println("in func list:", err)
			return res
		}
		res += "\n\t" + userName
	}
	if err := rows.Err(); err != nil {
		log.Println("in func list:", err)
	}
	return res
}
// sendTo sends content from c to the friend userName and reports whether
// the message was accepted. It returns false if userName is not in c's
// friend list, or if an offline message could not be stored.
//
// If the friend is online (present in clients), the message is pushed to
// their connection and recorded in logsOf<friend>. Otherwise it is stored
// in messagesOf<friend> for toRead. Either way it is recorded in logsOf<c>.
// mu is held throughout because clients is shared. Database errors are
// logged instead of terminating the server.
func (c *Client) sendTo(userName, content string) bool {
	mu.Lock()
	defer mu.Unlock()
	// Confirms the friendship and yields the name as stored.
	err := db.QueryRow("select userName from friendsOf"+c.userName+" where userName = ?", userName).Scan(&userName)
	if err != nil {
		if err != sql.ErrNoRows {
			log.Println("in func sendTo:", err)
		}
		return false
	}
	t := time.Now().Format("2006-01-02 15:04:05")
	if to := clients[userName]; to != nil {
		to <- "\n" + t + " " + c.userName + " " + content + "\n"
		to <- ">>"
		if _, err := db.Exec("insert logsOf"+userName+" set wtime=?, fromwho=?, towho=?, msg=?", t, c.userName, userName, content); err != nil {
			log.Println("in func sendTo: table(logsOf"+userName+"): ", err)
		}
	} else if _, err := db.Exec("insert messagesOf"+userName+" set wtime=?, fromwho=?, msg=?", t, c.userName, content); err != nil {
		log.Println("in func sendTo: table(messagesOf"+userName+"): ", err)
		return false // not delivered and not stored
	}
	if _, err := db.Exec("insert logsOf"+c.userName+" set wtime=?, fromwho=?, towho=?, msg=?", t, c.userName, userName, content); err != nil {
		log.Println("in func sendTo: table(logsOf"+c.userName+"): ", err)
	}
	return true
}
// sendAll sends content from c to every friend of c and reports whether
// every delivery succeeded. It uses the same delivery rules as sendTo:
// online friends get it on their connection and in their log; offline
// friends get it in messagesOf<friend>; every delivered message is also
// recorded in logsOf<c>.
//
// Failures are logged per friend, and the remaining friends are still
// served. mu is held throughout because clients is shared.
func (c *Client) sendAll(content string) bool {
	mu.Lock()
	defer mu.Unlock()
	t := time.Now().Format("2006-01-02 15:04:05")
	msg := "\n" + t + " " + c.userName + " " + content + "\n"

	rows, err := db.Query("select userName from friendsOf" + c.userName)
	if err != nil {
		log.Println("in func sendAll: select friendsOf: ", err)
		return false
	}
	defer rows.Close()
	// Read the whole friend list before delivering.
	var friends []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			log.Println("in func sendAll: scan friendsOf: ", err)
			return false
		}
		friends = append(friends, name)
	}
	if err := rows.Err(); err != nil {
		log.Println("in func sendAll: select friendsOf: ", err)
		return false
	}

	ok := true
	for _, userName := range friends {
		if to := clients[userName]; to != nil {
			to <- msg
			to <- ">>"
			if _, err := db.Exec("insert logsOf"+userName+" set wtime=?, fromwho=?, towho=?, msg=?", t, c.userName, userName, content); err != nil {
				log.Println("in func sendAll: insert logsOf: ", err)
				ok = false
			}
		} else if _, err := db.Exec("insert messagesOf"+userName+" set wtime=?, fromwho=?, msg=?", t, c.userName, content); err != nil {
			// messagesOf has exactly three placeholders (wtime, fromwho, msg);
			// passing the recipient too made every offline delivery fail.
			log.Println("in func sendAll: insert messagesOf: ", err)
			ok = false
			continue // neither delivered nor stored: don't log it as sent
		}
		if _, err := db.Exec("insert logsOf"+c.userName+" set wtime=?, fromwho=?, towho=?, msg=?", t, c.userName, userName, content); err != nil {
			log.Println("in func sendAll: insert logsOfsender: ", err)
			ok = false
		}
	}
	return ok
}
client.go
package main
import (
"bufio"
"fmt"
"io"
"log"
"net"
"os"
"strings"
)
var (
	// ch carries the lines to send to the server; clientWriter forwards
	// each one as a separate line.
	ch = make(chan string)
	// cin reads the user's commands and arguments from standard input.
	cin = bufio.NewScanner(os.Stdin)
)
// main connects to the chat server on localhost:8000 and copies everything
// the server sends to stdout. It turns each stdin line into a protocol
// command, prompting for arguments where needed; unknown input is sent as
// "none", which the server rejects with its help text.
//
// When stdin reaches EOF, main lets pending output reach the server, then
// half-closes the connection and waits until the server has closed its
// side.
func main() {
	conn, err := net.Dial("tcp", "localhost:8000")
	if err != nil {
		log.Fatal(err)
	}
	// Deferred only after the error check: calling Close on a nil conn panics.
	defer conn.Close()

	done := make(chan struct{})
	go func() {
		io.Copy(os.Stdout, conn) // NOTE: ignoring errors
		log.Println("done")
		done <- struct{}{} // signal the main goroutine
	}()
	writerDone := make(chan struct{})
	go func() {
		clientWriter(conn, ch)
		close(writerDone)
	}()

	for cin.Scan() {
		// Commands are matched with all spaces removed, so " list " works too.
		op := strings.ReplaceAll(cin.Text(), " ", "")
		switch op {
		case "register":
			register()
		case "login":
			login()
		case "logoff":
			ch <- "logoff"
		case "exit":
			os.Exit(0)
		case "add":
			add()
		case "delete":
			delete() // the package-level delete defined below, not the builtin
		case "list":
			ch <- "list"
		case "sendTo":
			sendTo()
		case "sendAll":
			sendAll()
		default:
			ch <- "none"
		}
	}

	// Stdin is exhausted. Let clientWriter flush what it has, then half-close
	// the connection: the server sees EOF and closes its end, which ends the
	// io.Copy goroutine. The deferred conn.Close alone would only run after
	// <-done, which would never arrive.
	close(ch)
	<-writerDone
	if tcp, ok := conn.(*net.TCPConn); ok {
		if err := tcp.CloseWrite(); err != nil {
			log.Println(err)
		}
	} else {
		conn.Close()
	}
	<-done // wait for background goroutine to finish
}
// clientWriter writes each string received on ch to conn as its own
// newline-terminated line, returning once ch is closed. Write errors are
// ignored.
func clientWriter(conn net.Conn, ch <-chan string) {
	for {
		line, open := <-ch
		if !open {
			return
		}
		fmt.Fprintln(conn, line)
	}
}
// register sends the "register" command, then prompts on stdout for a user
// name, password and phone number, and sends each answer as its own line.
// An answer missing because stdin ended is sent as "".
//
// NOTE(review): the password is echoed on the terminal as it is typed.
func register() {
	ch <- "register"
	ask := func(prompt string) string {
		fmt.Print(prompt)
		if !cin.Scan() {
			return ""
		}
		return cin.Text()
	}
	userName := ask("请输入用户名:")
	password := ask("请输入密码:")
	tel := ask("请输入手机号:")
	for _, field := range [...]string{userName, password, tel} {
		ch <- field
	}
}
// login sends the "login" command, then prompts on stdout for the user
// name and password and sends each answer as its own line. An answer
// missing because stdin ended is sent as "".
//
// NOTE(review): the password is echoed on the terminal as it is typed.
func login() {
	ch <- "login"
	ask := func(prompt string) string {
		fmt.Print(prompt)
		if !cin.Scan() {
			return ""
		}
		return cin.Text()
	}
	name := ask("请输入用户名:")
	pass := ask("请输入密码:")
	ch <- name
	ch <- pass
}
// add prompts for the user name of the friend to add, then sends the "add"
// command followed by that name. Nothing is sent before the name is read.
func add() {
	fmt.Print("请输入要添加好友的用户名:")
	friend := ""
	if cin.Scan() {
		friend = cin.Text()
	}
	for _, line := range [...]string{"add", friend} {
		ch <- line
	}
}
// delete prompts for the user name of the friend to remove, then sends the
// "delete" command followed by that name.
//
// NOTE(review): as a package-level function this shadows the builtin
// delete throughout the package; renaming it (e.g. deleteFriend) would
// avoid confusion.
func delete() {
	fmt.Print("请输入要删除好友的用户名:")
	friend := ""
	if cin.Scan() {
		friend = cin.Text()
	}
	for _, line := range [...]string{"delete", friend} {
		ch <- line
	}
}
// sendTo prompts for a friend's user name and a one-line message, then
// sends the "sendTo" command followed by the name and the message.
func sendTo() {
	ask := func(prompt string) string {
		fmt.Print(prompt)
		if !cin.Scan() {
			return ""
		}
		return cin.Text()
	}
	friend := ask("请输入好友用户名:")
	text := ask("请输入消息:")
	for _, line := range [...]string{"sendTo", friend, text} {
		ch <- line
	}
}
// sendAll prompts for a one-line message, then sends the "sendAll" command
// followed by the message.
func sendAll() {
	fmt.Print("请输入消息:")
	text := ""
	if cin.Scan() {
		text = cin.Text()
	}
	for _, line := range [...]string{"sendAll", text} {
		ch <- line
	}
}