tftp 协议的客户端下载和上传
#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#define ERR_MSG(msg) \
do { \
fprintf(stderr, "line %d: ", __LINE__); \
perror(msg); \
} while (0)
#define PORT 69
/**
* @brief 打印菜单
*/
void print_menu() {
printf("Select: \n");
printf("\t1. Download a file\n");
printf("\t2. Upload a file\n");
printf("\t3. help\n");
printf("\t4. quit\n");
}
/**
* @brief 用户选择功能
* @return 用户选择的数字
*/
int prompt() {
int num;
printf("> ");
scanf("%d", &num);
return num;
}
/**
* @brief 下载文件
* @param cfd 套接字
* @param sin 服务器地址信息结构体
* @param slen 服务器地址信息结构体的大小
* @return 0 成功, -1 失败
*/
int download(int cfd, struct sockaddr_in *sin, size_t slen) {
int fd; /* 保存的文件 */
int flag = 0; /* 函数返回的状态 */
int size, code, num; /* 收发缓存长度,操作数,数据包编码 */
char file_name[100]; /* 文件名 */
char buf[516] = {0}; /* 收发缓存 */
/* 保存服务器提供下载的端口 */
struct sockaddr_in cin;
cin.sin_family = AF_INET;
int cin_size;
/* 请求用户输入文件名称 */
printf("please input file name: ");
scanf("%99s", file_name); /* 文件名长度限制 99 个字符 */
/* 发送下载请求 */
size = sprintf(buf, "%c%c%s%c%s%c", 0x00, 0x01, file_name, 0, "octet", 0);
if (sendto(cfd, buf, size, 0, (struct sockaddr *)sin, slen) < 0) {
ERR_MSG("sendto");
return -1;
}
while (1) {
/* 接收数据包 */
size = recvfrom(cfd, buf, 516, 0, (struct sockaddr *)&cin, &cin_size);
if (size < 0) {
ERR_MSG("recvfrom");
flag = -1;
break;
}
code = ntohs(*(unsigned short *)buf); /* 操作码 */
num = ntohs(*(unsigned short *)(buf + 2)); /* 包编号 */
if (code == 5) {
printf("Fail: %s\n", buf + 4);
flag = -1;
break;
}
if (code != 3) {
printf("download failed.\n");
flag = -1;
break;
}
/* 第一个数据包的编码是 1 */
if (num == 1) {
/* 打开保存的文件 */
fd = open(file_name, O_RDWR | O_CREAT | O_TRUNC, 0666);
if (fd < 0) {
ERR_MSG("open");
return -1;
}
}
write(fd, buf + 4, size - 4);
bzero(buf, sizeof(buf));
*(unsigned short *)buf = htons(4);
*(unsigned short *)(buf + 2) = htons(num);
if (sendto(cfd, buf, 4, 0, (struct sockaddr *)&cin, cin_size) < 0) {
ERR_MSG("sendto");
flag = -1;
break;
}
if (size < 516) {
printf("download finished.\n");
break;
}
}
close(fd);
return flag;
}
/**
* @brief 上传文件
* @param cfd 套接字
* @param sin 服务器地址信息结构体
* @param slen 服务器地址信息结构体的大小
* @return 0 成功, -1 失败
*/
int upload(int cfd, struct sockaddr_in *sin, size_t slen) {
int fd; /* 上传的文件 */
int flag = 0; /* 函数返回的状态 */
int size, code, num; /* 收发缓存长度,操作数,数据包编码 */
char file_name[100]; /* 文件名 */
char buf[516] = {0}; /* 收发缓存 */
/* 保存服务器提供上传的端口 */
struct sockaddr_in cin;
cin.sin_family = AF_INET;
int cin_size;
/* 请求用户输入文件名称 */
printf("please input file name: ");
scanf("%99s", file_name); /* 文件名长度限制 99 个字符 */
/* 打开文件 */
fd = open(file_name, O_RDONLY);
if (fd < 0) {
ERR_MSG("open");
return -1;
}
/* 发送上传请求 */
size = sprintf(buf, "%c%c%s%c%s%c", 0x00, 0x02, file_name, 0, "octet", 0);
if (sendto(cfd, buf, size, 0, (struct sockaddr *)sin, slen) < 0) {
ERR_MSG("sendto");
return -1;
}
while (1) {
/* 接收 ACK 包 */
size = recvfrom(cfd, buf, 516, 0, (struct sockaddr *)&cin, &cin_size);
if (size < 0) {
ERR_MSG("recvfrom");
flag = -1;
break;
}
code = ntohs(*(unsigned short *)buf); /* 操作码 */
num = ntohs(*(unsigned short *)(buf + 2)); /* 包编号 */
if (code == 5) {
printf("Fail: %s\n", buf + 4);
flag = -1;
break;
}
if (code != 4) {
printf("upload failed.\n");
flag = -1;
break;
}
/* Code == 4,说明上一个包上传成功,发送下一个包 */
bzero(buf, sizeof(buf));
size = read(fd, buf + 4, 512);
if (size < 0) {
ERR_MSG("read");
flag = -1;
break;
}
*(unsigned short *)buf = htons(3);
*(unsigned short *)(buf + 2) = htons(num + 1);
if (sendto(cfd, buf, size + 4, 0, (struct sockaddr *)&cin, cin_size) < 0) {
ERR_MSG("sendto");
flag = -1;
break;
}
if (size < 512) {
printf("upload finished.\n");
break;
}
}
close(fd);
return flag;
}
int main(int argc, char const *argv[]) {
if (argc != 2) {
printf("usage: %s ip\n", argv[0]);
return -1;
}
/* 创建报式套接字 */
int cfd = socket(AF_INET, SOCK_DGRAM, 0);
if (cfd < 0) {
ERR_MSG("socket");
return -1;
}
/* 填充服务器的地址信息结构体 */
struct sockaddr_in sin;
sin.sin_family = AF_INET; /* 必须填充 AF_INET */
sin.sin_port = htons(PORT); /* tftp server port --> 69 */
sin.sin_addr.s_addr = inet_addr(argv[1]); /* tftp server ip */
print_menu();
while (1) {
int select = prompt();
if (select == 1) {
printf("select download...\n");
if (download(cfd, &sin, sizeof(sin)) < 0) {
ERR_MSG("download");
break;
}
} else if (select == 2) {
printf("select upload...\n");
if (upload(cfd, &sin, sizeof(sin)) < 0) {
ERR_MSG("upload");
break;
}
} else if (select == 3) {
print_menu();
} else if (select == 4) {
printf("see you!\n");
close(cfd);
break;
}
}
}