服务端代码:
package main
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"syscall"
)
// initEnv sets GOMAXPROCS and prints the value it chose.
//
// With an argument, i[0] is used as the processor count. Without one, the
// count is probed from the OS (/proc/cpuinfo on Linux, the
// NUMBER_OF_PROCESSORS environment variable elsewhere). A non-positive or
// unparseable count falls back to runtime.NumCPU().
//
// The previous version passed the raw shell output (which ends in "\n") to
// strconv.Atoi. That always failed, so it ended up calling GOMAXPROCS(0),
// which changes nothing.
func initEnv(i ...int) {
	var c int
	if len(i) > 0 {
		c = i[0]
	} else {
		c = probeCPUCount()
	}
	if c < 1 {
		c = runtime.NumCPU()
	}
	fmt.Println(c)
	runtime.GOMAXPROCS(c)
}

// probeCPUCount asks the OS for the number of logical processors and returns
// 0 when the count cannot be determined.
func probeCPUCount() int {
	var raw string
	if runtime.GOOS == "linux" {
		out, err := exec.Command("/bin/bash", "-c", "cat /proc/cpuinfo | grep 'processor' | sort | uniq | wc -l").Output()
		if err != nil {
			return 0
		}
		raw = string(out)
	} else {
		// Environment variable names are case-insensitive on Windows.
		raw = os.Getenv("number_of_processors")
	}
	// Command output and env values may carry trailing whitespace/newlines.
	n, err := strconv.Atoi(strings.TrimSpace(raw))
	if err != nil {
		return 0
	}
	return n
}
// sendAll writes all of data to conn and returns the first write error.
//
// The old loop ignored the error from conn.Write. Once the peer went away,
// Write kept returning (0, err), so the loop never terminated. Callers that
// ignore the result (statement calls) keep compiling unchanged.
func sendAll(conn net.Conn, data []byte) error {
	for len(data) > 0 {
		n, err := conn.Write(data)
		if err != nil {
			return err
		}
		data = data[n:]
	}
	return nil
}
// getIndex reports the byte offset of the first rune in str equal to c, or -1
// when there is none. Invalid UTF-8 bytes compare equal to utf8.RuneError.
func getIndex(str string, c rune) int {
	return strings.IndexRune(str, c)
}
// rIndex reports the byte offset of the last occurrence of byte c in str, or
// -1 when c does not occur.
func rIndex(str string, c uint8) int {
	return strings.LastIndexByte(str, c)
}
// getDir returns the directory part of path, everything before the last
// separator, after normalising backslashes to forward slashes.
//
// A path with no separator returns "." and a path directly under the root
// ("/x") returns "/". Previously the first case sliced path[:-1] and panicked.
// The second returned "", which os.MkdirAll cannot create.
func getDir(path string) string {
	path = strings.ReplaceAll(path, "\\", "/")
	switch i := strings.LastIndexByte(path, '/'); {
	case i < 0:
		return "."
	case i == 0:
		return "/"
	default:
		return path[:i]
	}
}
func handle(conn net.Conn) {
//_ = conn.SetReadDeadline(time.Now().Add(time.Second * 60*5))
//_ = conn.SetWriteDeadline(time.Now().Add(time.Second * 60*5))
defer func() {
if err := recover(); err != nil {
fmt.Println(err)
fmt.Println("exit")
return
}
}()
defer conn.Close()
data := make([]byte, 40960)
n, err := conn.Read(data)
if err != nil {
return
}
res := bytes.Split(data[:n], []byte("|"))
if len(res) < 2 {
return
}
mode := string(res[0])
file := string(res[1])
down := "download"
upload := "upload"
if mode == down {
//下载处理
fmt.Println("download.....", string(file))
f, _ := os.Open(string(file))
defer f.Close()
if f != nil { //file is not nil , and then send file data
var buffer = make([]byte, 1024)
n, _ := f.Read(buffer)
for n != 0 {
sendAll(conn, buffer[:n])
n, err = f.Read(buffer)
}
} else {
fmt.Println("读取文件失败", string(file))
}
} else if mode == upload {
//上传处理
fmt.Println("upload...", file)
//开始创建保存目录
dir := getDir(file)
fmt.Println("dir:", dir)
_, err = os.Stat(dir)
if os.IsNotExist(err) {
e := os.MkdirAll(dir, os.ModePerm)
if e != nil {
fmt.Printf("不能创建目录")
panic(e)
}
}
var f *os.File
var err error
if b,_:=pathExists(file);b{
fmt.Println(b,"exist")
f, err = os.OpenFile(file,syscall.O_WRONLY,0)
}else {
f, err = os.Create(file) //创建文件
}
defer func() {
if f != nil {
f.Close()
}
}()
if err != nil {
fmt.Println("文件创建失败",file)
panic(fmt.Sprintf("%v%s",err,"文件创建失败"))
}
buffer := make([]byte, 1024)
for {
n, err := conn.Read(buffer)
if err != nil {
break
}
_, _ = f.Write(buffer[:n])
}
fmt.Println("upload ok")
}
fmt.Println("done")
}
// pathExists reports whether something exists at path.
//
// It returns (true, nil) when os.Stat succeeds and (false, nil) when the path
// does not exist. Any other Stat failure, such as a permission error, is
// returned as (false, err).
func pathExists(path string) (bool, error) {
	_, statErr := os.Stat(path)
	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return false, statErr
	}
}
func run(host,port,sysType,interpreter,arg1,cmdGetPid string) {
start:
tcpServer, _ := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%s",host,port))
listener, _ := net.ListenTCP("tcp", tcpServer)
fmt.Println("run server ... ")
retry:=20
for {
//当有新的客户端请求来的时候,拿到与客户端的连接
conn, err := listener.Accept()
if err != nil {
if retry<=0{
panic("端口被占用")
}
if sysType == "linux"{
_ = exec.Command(interpreter,arg1,cmdGetPid).Run()
retry--
goto start
}
getPID:=exec.Command(interpreter,arg1,fmt.Sprintf(cmdGetPid,port))
byteRes, _ :=getPID.Output()
res:=string(byteRes)
split:="\n"
if strings.Contains(res,"\r\n"){
split="\r\n"
}
pidInfoList :=strings.Split(res,split)
pidSet :=make(map[string]bool)
for _,line:=range pidInfoList {
if len(line)<3{
continue
}
pid:=line[rIndex(line,' ')+1:]
if _,exist:= pidSet[pid];!exist{
kill:=exec.Command("cmd.exe","/c",fmt.Sprintf("taskkill /f /t /im %s",pid))
_ = kill.Run()
fmt.Println(pid)
pidSet[pid]=true
}
}
retry--
goto start
}
go handle(conn)
}
}
// GetAllFile walks pathname depth-first and appends the path of every
// non-directory entry to **res. Paths are built as pathname + "/" + name.
// Entries are visited in name order, the order ioutil.ReadDir returns them.
//
// *res is replaced with a pointer to the grown slice after each append. Only
// the error from reading pathname itself is returned. Failures inside
// sub-directories are skipped, so the walk is best effort.
func GetAllFile(res **[]string, pathname string) error {
	entries, readErr := ioutil.ReadDir(pathname)
	for _, entry := range entries {
		full := pathname + "/" + entry.Name()
		if !entry.IsDir() {
			grown := append(**res, full)
			*res = &grown
			continue
		}
		_ = GetAllFile(res, full)
	}
	return readErr
}
// toList renders files as a list literal, e.g. ["G:/a.txt","G:/b.txt"], and
// prints it. The Python client can parse it with ast.literal_eval.
//
// Each name is quoted with strconv.Quote, so quotes, backslashes and control
// characters are escaped. Its escape forms (\", \\, \n, \xNN, \uNNNN) are
// also valid in Python string literals. The old format wrapped names in '...'
// without escaping. A file name containing a quote therefore broke the list,
// and because the client eval()s the reply, it could inject Python code.
func toList(files []string) string {
	var b strings.Builder
	b.WriteByte('[')
	for i, s := range files {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(strconv.Quote(s))
	}
	b.WriteByte(']')
	res := b.String()
	fmt.Println(res)
	return res
}
// main starts the file-transfer server on the data port (see run) in the
// background. It then serves the control protocol on the control port until an
// "exit" command arrives or the control listener fails.
//
// Usage: prog [host port [controlPort]]. Defaults: host "" (all interfaces),
// data port 8080, control port 8081.
//
// NOTE(review): with the default empty host, both ports listen on every
// interface and expose unauthenticated filesystem access. Consider binding to
// 127.0.0.1 if remote access is not needed.
func main() {
	initEnv()
	host := ""
	port := "8080"
	portControl := "8081"
	// os.Args[0] is the program, so "prog host port" has len 3. The old check
	// (len > 3) silently ignored exactly that invocation.
	if len(os.Args) > 2 {
		host = os.Args[1]
		port = os.Args[2]
	}
	if len(os.Args) > 3 {
		portControl = os.Args[3]
	}
	sysType := runtime.GOOS
	fmt.Println(sysType)
	interpreter := "cmd.exe"
	arg1 := "/c"
	cmdGetPid := "netstat -ano|findstr %s"
	if sysType == "linux" {
		// Linux: %s is replaced with the port by the caller (run). Kill every
		// process whose local address ends in ":<port>", skipping $PPID (this
		// server, the parent of the bash process). The old script had a syntax
		// error ("do;") and never ran.
		interpreter = "/bin/bash"
		arg1 = "-c"
		cmdGetPid = `for i in $(netstat -anp 2>/dev/null | awk '$4 ~ /:%s$/ {split($NF, a, "/"); if (a[1] ~ /^[0-9]+$/) print a[1]}' | sort -u); do [ "$i" = "$PPID" ] || kill $i; done`
	}
	go run(host, port, sysType, interpreter, arg1, cmdGetPid)

	tcpAddr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%s", host, portControl))
	if err != nil {
		fmt.Println(err)
		return
	}
	listener, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer listener.Close()
	for {
		conn, err := listener.Accept()
		if err != nil {
			fmt.Println(err)
			return
		}
		if handleControl(conn) {
			fmt.Println("exit ----------------------")
			return
		}
	}
}

// handleControl serves one control request and always closes conn. It reports
// whether the server should shut down.
//
// A request is a single read of "code|args". "get_files|<dir>" replies with
// every file under dir, formatted by toList. "exit" requests shutdown. Any
// other code is ignored.
func handleControl(conn net.Conn) bool {
	defer conn.Close()
	buffer := make([]byte, 4096)
	n, _ := conn.Read(buffer) // on error n == 0 and the empty code is ignored
	res := string(buffer[:n])
	list := strings.Split(res, "|")
	code := list[0]
	var args string
	if len(list) > 1 {
		args = list[1]
	}
	fmt.Println(res, code, list)
	switch code {
	case "get_files":
		fmt.Println(args)
		files := make([]string, 0)
		pFiles := &files
		if err := GetAllFile(&pFiles, args); err != nil {
			fmt.Println(err)
		}
		sendAll(conn, []byte(toList(*pFiles)))
	case "exit":
		return true
	}
	return false
}
客户端:
import os
import socket
import threading
def download(target, local):
    """Fetch the server file ``target`` into the local path ``local``.

    Sends ``download|<target>`` to the data port (127.0.0.1:8080) and writes
    every byte the server returns until it closes the connection. Missing
    parent directories of ``local`` are created. Connection attempts are
    retried until one succeeds.

    Returns immediately when ``target`` is not a file on *this* machine. That
    check only makes sense because client and server run on the same host.
    """
    import time

    if not os.path.isfile(target):
        return
    print(target, local)
    while True:
        # A socket whose connect() failed should not be reused, so create a
        # fresh one per attempt. Back off briefly instead of busy-spinning.
        client = socket.socket()
        try:
            client.connect(('127.0.0.1', 8080))
            break
        except Exception as e:
            print(e)
            client.close()
            time.sleep(0.1)
    with client:
        client.sendall(bytes("download|%s" % target, encoding='utf8'))
        dir_name = os.path.dirname(local)
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        with open(local, 'wb') as f:
            fragment = client.recv(65536)
            while fragment:
                f.write(fragment)
                fragment = client.recv(65536)
    print('退出了')
def upload(local, target):
    """Send the local file ``local`` to the server, which saves it as ``target``.

    Sends ``upload|<target>`` to the data port (127.0.0.1:8080), then the raw
    file bytes. Closing the socket marks the end of the upload. The local file
    is opened before connecting, so a missing file no longer leaves an empty
    ``target`` on the server. The socket is closed even if sending fails.

    NOTE(review): the server parses the header from a single read. File bytes
    that arrive in the same segment end up inside the path. Reliable framing
    would need a protocol change on both sides.
    """
    with socket.socket() as client, open(local, 'rb') as f:
        client.connect(('127.0.0.1', 8080))
        client.sendall(bytes('upload|%s' % target, encoding='utf8'))
        for chunk in iter(lambda: f.read(65536), b''):
            client.sendall(chunk)
def get_files(path):
    """Ask the server's control port (127.0.0.1:8081) for every file under ``path``.

    Returns the raw reply bytes: the server's list literal of path strings.
    Decode it and parse it with ``ast.literal_eval``, never ``eval``. The
    socket is closed even if the exchange fails.
    """
    with socket.socket() as client:
        client.connect(('127.0.0.1', 8081))
        client.sendall(("get_files|%s" % path).encode('utf8'))
        chunks = []
        for chunk in iter(lambda: client.recv(65536), b''):
            chunks.append(chunk)
    return b''.join(chunks)
def close():
    """Send ``exit`` to the control port (127.0.0.1:8081) to stop the server.

    The socket is now closed afterwards; the old version never closed it.
    """
    with socket.socket() as client:
        client.connect(('127.0.0.1', 8081))
        client.sendall(b'exit')
import ast

# file_list = get_files('G:/Driver驱动')
# get_files returns raw bytes; parse them into a list of path strings.
# Iterating the bytes directly yielded ints, and x[1:] then raised TypeError.
# literal_eval only accepts literals, so a hostile file name cannot run code
# the way eval() could.
file_list = ast.literal_eval(get_files('G:/test000').decode('utf8'))
# Mirror each file onto drive F by replacing the first character (the drive
# letter), e.g. "G:/test000/a" -> "F:/test000/a".
save_list = ["F" + x[1:] for x in file_list]
threads = [threading.Thread(target=download, args=pair)
           for pair in zip(file_list, save_list)]
for t in threads:
    t.start()
# Wait for every download before sending "exit"; otherwise the server can shut
# down while transfers are still running.
for t in threads:
    t.join()
res = get_files('G:/test000')
for i in ast.literal_eval(res.decode('utf8')):
    print(i)
close()