最近有一个需求,需要将一些Linux用户态的命令做成自动化。
比如在用户态执行lspci命令,判断获取的设备中是否有某个型号的pci卡,这就需要linux内核态和用户态交互。实现的方法是通过linux内核态编程。在内核驱动中通过socket发送一个用户态请求,server端接收到请求并执行,执行后将结果返回给内核驱动,驱动中判断结果。
内核态socket编程的过程和用户态下的socket编程流程一样,但是接口不同。Kernel提供了一组内核态的socket API,基本上用户态的socket API在内核中都有对应的API。
在net/socket.c中可以看到导出符号:
主要实现:
server端代码:
#include <arpa/inet.h>
#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#include "public.h"
/* Listening socket: created in main() and bound/listened on in server_init(). */
int server_listen_fd;
/* Socket of the most recently accepted connection (assigned in server_accept()). */
int server_accept_fd;
/* TCP port the server binds to (used by server_init()). */
int port= 2002;
/*
 * Return the current wall-clock time expressed in microseconds.
 *
 * The value is built as seconds * 1000000 + microseconds from
 * gettimeofday() and is then converted to the int return type, so only
 * the part that fits in an int survives.  Callers should therefore only
 * rely on differences between two nearby readings, not on the absolute
 * value.
 */
int get_current_time()
{
    struct timeval now;
    int micros;

    gettimeofday(&now, NULL);
    micros = now.tv_sec * 1000000 + now.tv_usec;

    return micros;
}
/*
 * Upper bound (including the terminating NUL) on how much excute_cmd()
 * will ever store in the caller's result buffer.
 * NOTE(review): assumes the result buffer holds at least this many bytes;
 * 1024 matches the buffer sizes used elsewhere in this file - confirm
 * against the rspbuf definition in public.h and override if it differs.
 */
#ifndef EXCUTE_CMD_RESULT_MAX
#define EXCUTE_CMD_RESULT_MAX 1024
#endif

/*
 * Run `cmd` through popen() and append its standard output to `result`.
 *
 * `result` must already contain a NUL-terminated string (possibly empty);
 * output is appended after it, as before.  Unlike the previous strcat()
 * based version the total length is capped at EXCUTE_CMD_RESULT_MAX - 1
 * characters, so oversized command output is truncated instead of
 * overflowing the caller's buffer.
 *
 * Returns 0 when the command was started and its pipe closed cleanly
 * (the command's own exit status is not inspected), 1 on invalid
 * arguments, when `result` is not terminated within the cap, or when
 * popen()/pclose() fail.
 */
int excute_cmd(char* cmd, char* result) {
    char buffer[1024];
    const char *terminator;
    size_t used;
    FILE *pipe;

    if (cmd == NULL || result == NULL)
        return 1;

    /* Locate the end of the existing contents without running off the buffer. */
    terminator = memchr(result, '\0', EXCUTE_CMD_RESULT_MAX);
    if (terminator == NULL)
        return 1;
    used = (size_t)(terminator - result);

    pipe = popen(cmd, "r");
    if (!pipe)
        return 1;

    /* fgets() returning NULL covers both EOF and read errors. */
    while (fgets(buffer, sizeof(buffer), pipe) != NULL) {
        size_t chunk = strlen(buffer);
        size_t room = (size_t)EXCUTE_CMD_RESULT_MAX - 1 - used;

        if (chunk > room)
            chunk = room;
        memcpy(result + used, buffer, chunk);
        used += chunk;
        result[used] = '\0';
        /* Keep draining even when full so the child is not left blocked. */
    }

    if (pclose(pipe) == -1)
        return 1;
    return 0;
}
/*
 * Receive exactly `len_recv` bytes from `fd` into request->reqbuf.
 *
 * fd       - connected socket to read from
 * request  - destination; bytes are stored starting at request->reqbuf
 * len_recv - number of bytes to collect
 * timeout  - overall time budget in milliseconds
 *
 * Returns the number of bytes received (== len_recv) on success, 0 if the
 * time budget ran out first, and -1 on invalid arguments, a socket error,
 * or when the peer closed the connection.
 *
 * The socket is read with MSG_DONTWAIT; poll() is used to wait for data so
 * that "no data yet" (EAGAIN) is no longer reported as an error and the
 * wait does not spin.
 *
 * NOTE(review): the caller passes sizeof(cmd_request) as len_recv; this is
 * only safe if reqbuf spans the whole structure - confirm in public.h.
 */
int server_recv(int fd, cmd_request *request, int len_recv, int timeout)
{
    unsigned int start;
    int recved = 0;

    if (request == NULL || len_recv <= 0)
        return -1;

    /*
     * get_current_time() yields microseconds squeezed into an int; doing
     * the subtraction in unsigned arithmetic keeps the elapsed time
     * correct even when that value wraps.
     */
    start = (unsigned int)get_current_time();
    printf("server recv start %d time:%d!\n", len_recv, timeout);

    while (recved < len_recv) {
        unsigned int elapsed_ms =
            ((unsigned int)get_current_time() - start) / 1000u;
        struct pollfd pfd;
        int remaining_ms = 0;
        int ready;
        ssize_t len;

        if (timeout > 0 && elapsed_ms < (unsigned int)timeout)
            remaining_ms = timeout - (int)elapsed_ms;

        pfd.fd = fd;
        pfd.events = POLLIN;
        pfd.revents = 0;
        ready = poll(&pfd, 1, remaining_ms);
        if (ready < 0) {
            if (errno == EINTR)
                continue;
            printf("server recv error!\n");
            return -1;
        }
        if (ready == 0) {
            printf("server recv timeout!\n");
            return 0;
        }

        len = recv(fd, request->reqbuf + recved,
                   (size_t)(len_recv - recved), MSG_DONTWAIT);
        if (len < 0) {
            /* Spurious wake-up or interrupted call: wait again. */
            if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
                continue;
            printf("server recv error!\n");
            return -1;
        }
        if (len == 0) {
            /* Orderly shutdown by the peer before all bytes arrived. */
            printf("server recv error!\n");
            return -1;
        }

        recved += (int)len;
        printf("server has recv %d\n", (int)len);
        /* Bounded print: the buffer need not be NUL-terminated yet. */
        printf("recv buf is %.*s:\n", recved, request->reqbuf);
    }

    printf("recved %d bytes!\n", recved);
    return recved;
}
/*
 * Wait for the next client on server_listen_fd and store the connected
 * socket in server_accept_fd.
 *
 * The new connection (not a previously accepted one) is switched to
 * non-blocking mode and gets TCP_NODELAY.  poll() is used to sleep until
 * a connection is pending instead of spinning on EAGAIN.
 *
 * Returns 0 on success.  An unrecoverable poll()/accept() failure
 * terminates the process with exit(1), as before.  The descriptor left in
 * server_accept_fd belongs to the caller, who is responsible for closing it.
 */
int server_accept()
{
    struct sockaddr_in server_accept_addr;
    socklen_t size;
    int opt = 1;
    int flags;
    int fd;

    for (;;) {
        struct pollfd pfd;

        pfd.fd = server_listen_fd;
        pfd.events = POLLIN;
        pfd.revents = 0;
        if (poll(&pfd, 1, -1) < 0) {
            if (errno == EINTR)
                continue;
            perror("error:socket accept1 exited!\n");
            exit(1);
        }

        memset(&server_accept_addr, 0, sizeof(server_accept_addr));
        size = sizeof(server_accept_addr);
        fd = accept(server_listen_fd,
                    (struct sockaddr*)&server_accept_addr, &size);
        if (fd >= 0)
            break;

        /* Transient conditions: the pending connection vanished or we were interrupted. */
        if (errno == EAGAIN || errno == EWOULDBLOCK ||
            errno == EINTR || errno == ECONNABORTED)
            continue;

        perror("error:socket accept1 exited!\n");
        exit(1);
    }

    /* Socket options are best-effort: report failures but keep the connection. */
    flags = fcntl(fd, F_GETFL, 0);
    if (flags == -1 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1)
        perror("fcntl(O_NONBLOCK)");
    if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)) == -1)
        perror("setsockopt(TCP_NODELAY)");

    server_accept_fd = fd;
    return 0;
}
/*
 * Bind server_listen_fd to ip:port, make it non-blocking and start
 * listening (backlog 10).
 *
 * ip - dotted IPv4 address to bind to.
 *
 * Returns 0 on success.  A missing or unparsable address, or a failing
 * bind()/listen(), terminates the process with exit(1).  Failures of the
 * optional SO_REUSEADDR / O_NONBLOCK settings are reported but not fatal.
 */
int server_init(char *ip)
{
    struct sockaddr_in server_listen_addr;
    int opt = 1;
    int flags;

    if (ip == NULL) {
        fprintf(stderr, "server_init: no ip address given\n");
        exit(1);
    }

    memset(&server_listen_addr, 0, sizeof(server_listen_addr));
    server_listen_addr.sin_family = AF_INET;
    server_listen_addr.sin_port = htons(port);
    server_listen_addr.sin_addr.s_addr = inet_addr(ip);
    /* inet_addr() signals a malformed address with INADDR_NONE. */
    if (server_listen_addr.sin_addr.s_addr == INADDR_NONE) {
        fprintf(stderr, "server_init: invalid ip address '%s'\n", ip);
        exit(1);
    }

    if (setsockopt(server_listen_fd, SOL_SOCKET, SO_REUSEADDR,
                   &opt, sizeof(opt)) == -1)
        perror("setsockopt(SO_REUSEADDR)");

    if (bind(server_listen_fd, (struct sockaddr*)&server_listen_addr,
             sizeof(server_listen_addr)) == -1) {
        perror("can't to bind");
        exit(1);
    }

    flags = fcntl(server_listen_fd, F_GETFL, 0);
    if (flags == -1 ||
        fcntl(server_listen_fd, F_SETFL, flags | O_NONBLOCK) == -1)
        perror("fcntl(O_NONBLOCK)");

    if (listen(server_listen_fd, 10) == -1) {
        /* Previously reported as a bind failure. */
        perror("can't listen");
        exit(1);
    }
    return 0;
}
/*
 * Command server entry point.
 *
 * Usage: <program> <listen-ip>
 *
 * Creates the listening socket, then serves clients one at a time:
 * accept a connection, read one request, run request.reqbuf as a shell
 * command, send back the whole response.rspbuf buffer and close the
 * connection.
 *
 * SECURITY NOTE(review): the received bytes are executed verbatim via
 * excute_cmd(); anyone who can reach the port can run arbitrary commands
 * with this process's privileges.  Restrict the listen address or add
 * authentication/whitelisting before using this outside a lab.
 */
int main(int argc,char *argv[])
{
    int ret;
    int timeout = 1000;   /* receive budget per request, in milliseconds */
    cmd_request request;
    cmd_response response;

    if (argc < 2 || !argv[1]) {
        printf("need ip!\n");
        return -1;
    }

    server_listen_fd = socket(AF_INET,SOCK_STREAM,0);
    if (-1 == server_listen_fd) {
        perror("fail to create socket!");
        exit(1);
    }
    server_init(argv[1]);

    while (1) {
        size_t total = sizeof(response.rspbuf);
        size_t sent = 0;

        /*
         * Start every request from clean buffers: excute_cmd() appends to
         * rspbuf, so stale output from the previous client must not remain.
         */
        memset(&request, 0, sizeof(request));
        memset(&response, 0, sizeof(response));

        printf("server socket begin accept:\n");
        server_accept();

        //recv
        ret = server_recv(server_accept_fd, &request, sizeof(cmd_request), timeout);
        if (ret <= 0) {
            printf("server recv error!\n");
            /* Nothing valid to execute: drop this client and wait for the next. */
            close(server_accept_fd);
            continue;
        }
        /* The peer is not trusted to send a terminated string. */
        request.reqbuf[sizeof(request.reqbuf) - 1] = '\0';
        printf("DATA:[%s]\n", request.reqbuf);

        if (excute_cmd(request.reqbuf, response.rspbuf) != 0)
            printf("excute cmd failed!\n");
        response.rspbuf[sizeof(response.rspbuf) - 1] = '\0';
        printf("the result is %s",response.rspbuf);

        //send: loop because a single send() may transmit only part of the buffer
        while (sent < total) {
            /* MSG_NOSIGNAL: a vanished client must not kill the server with SIGPIPE. */
            ssize_t n = send(server_accept_fd, response.rspbuf + sent,
                             total - sent, MSG_NOSIGNAL);

            if (n > 0) {
                sent += (size_t)n;
                continue;
            }
            if (n < 0 && (errno == EINTR || errno == EAGAIN ||
                          errno == EWOULDBLOCK))
                continue;
            printf("send %d failed!\n", (int)n);
            break;
        }

        /* One request per connection: release the descriptor. */
        close(server_accept_fd);
    }

    /* Not reached: the serving loop above never terminates. */
    close(server_listen_fd);
    return 0;
}
客户端实现:
#include <linux/module.h>
#include <linux/init.h>
#include <linux/socket.h>
#include <net/sock.h>
#include <linux/in.h>
#include <linux/tcp.h>
#include <linux/in.h>
#include <linux/inet.h>
#include <linux/time.h>
#include "public.h"
/* TCP port of the server that client_init() connects to. */
int port_id = 2002;
/* IPv4 address of the server; default value, overridable at load time. */
char * dst_ip = "192.168.0.103";
/* Expose dst_ip as a string module parameter with S_IRUSR permissions. */
module_param(dst_ip, charp, S_IRUSR);
/* Request/response pointers; not referenced by the code visible in this file. */
cmd_request *prequest;
cmd_response *presponse;
/* Forward declaration of the client routine defined further down. */
int client_init(void);
/*
 * Return the current time of day in milliseconds.
 *
 * The value is converted to int on return, so only differences between two
 * nearby readings are meaningful, not the absolute value.
 *
 * The milliseconds are computed as sec * 1000 + usec / 1000 in 64-bit
 * arithmetic.  This yields the same number as
 * (sec * 1000000 + usec) / 1000 but cannot overflow where long is 32 bits
 * wide, and needs no 64-bit division.
 */
int plugin_get_current_time()
{
    struct timeval stime;

    do_gettimeofday(&stime);
    return (int)((long long)stime.tv_sec * 1000 + stime.tv_usec / 1000);
}
int client_send(struct socket *sock, unsigned char *pbufsend, int len_send)
{
int len;
int sended = 0;
unsigned long last_time = plugin_get_current_time();
struct kvec vec;
struct msghdr msg;
unsigned short timeout = 1000;
while (1) {
if (plugin_get_current_time() - last_time > timeout){
printk("kernel send msg timeout!\n");
break;
}
vec.iov_base = pbufsend + sended;
vec.iov_len = len_send - sended;
msg.msg_flags = 0;
len = kernel_sendmsg(sock, &msg, &vec, 1, len_send - sended);
if (len < 0){
if (len == -EWOULDBLOCK){
printk("kernel send msg would block!\n");
continue;
} else {
printk("kernel send msg failed with < 0!\n");
return -EINVAL;
}
} else if (len == 0){
printk("kernel send msg failed with = 0!\n");
return -EINVAL;
}
sended += len;
if (sended >= len_send) {
printk("kernel send %d bytes!\n", sended);
return sended;
}
}
}
int client_init(void)
{
struct socket *sock;
struct sockaddr_in s_addr;
int ret = 0;
sock = (struct socket*)kmalloc(sizeof(struct socket), GFP_KERNEL);
memset(&s_addr, 0, sizeof(s_addr));
s_addr.sin_family = AF_INET;
s_addr.sin_port = htons(port_id);
s_addr.sin_addr.s_addr = in_aton(dst_ip);
/*create socket*/
ret = sock_create_kern(AF_INET, SOCK_STREAM, 0, &sock);
if (ret) {
printk("socket create failed\n");
return ret;
}
printk("create socket ok!\n");
/*connect server*/
ret = sock->ops->connect(sock, (struct sockaddr*)&s_addr, sizeof(struct sockaddr_in), 0);
if (ret) {
printk("socket connect server failed!\n");
return ret;
}
printk("connect server ok!\n");
// set opt
int opt = 1;
int flags;
kernel_setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *)&opt, sizeof(int));
flags = kernel_sock_ioctl(sock, F_GETFL, 0);
kernel_sock_ioctl(sock, F_SETFL, flags | O_NONBLOCK);
/*kmalloc sendbuf*/
char *sendbuf = NULL;
sendbuf = kmalloc(1024, GFP_KERNEL);
memset(sendbuf, 0, 1024);
strcpy(sendbuf, "lspci");
printk("the request is %s, size is %d\n", sendbuf, sizeof(sendbuf));
ret = client_send(sock, sendbuf, 1024);
if (ret <= 0) {
printk("client send failed !\n");
return -EINVAL;
}
recvbuf = kmalloc(1024, GFP_KERNEL);
memset(recvbuf, 0, 1024);
memset(&msg, 0, sizeof(msg));
memset(&vec, 0, sizeof(vec));
vec.iov_base = recvbuf;
vec.iov_len=1024;
int count = 0;
while (count < 1000) {
ret = kernel_recvmsg(sock, &msg, &vec, 1, 1024, 0);
if(ret < 0){
printk("client:kernel_sendmsg error!\n");
return ret;
} else if (ret > 0) {
printk("recv message %s\n",recvbuf);
break;
}
count ++;
}
if (count >= 1000)
printk("kernel recv msg timeout!\n");
//判断结果是否符合预期
char *expect = "PCI";
if (strstr(recvbuf, expect) != NULL) {
printk("lspci test pass!\n");
}
return ret;
}
/*
 * Disabled test-case entry point (compiled out by #if 0).
 * It references skip_tc and client_recv_and_send(), neither of which is
 * defined in the code visible in this file.
 */
#if 0
int tc_run(skip_tc *pskip_tc)
{
socket *psock;
pskip_tc->desc = psock;
client_init();
client_recv_and_send(pskip_tc,pskip_tc->request,sizeof(cmd_request),pskip_tc->response,sizeof(cmd_response));
}
#endif
/*
 * Module entry point: logs a greeting and runs the client once.
 *
 * A failing client run is reported in the kernel log but does not prevent
 * the module from loading, so the function always returns 0 as before.
 */
static int __init plugin_init(void)
{
    int ret;

    printk("hello, plugin!\n");

    ret = client_init();
    if (ret < 0)
        printk("plugin: client_init failed with %d\n", ret);

    return 0;
}
/* Module exit hook: only logs a farewell message. */
static void __exit plugin_exit(void)
{
printk("goodbye,plugin!\n");
}
/* Register the load/unload hooks and declare the module license. */
module_init(plugin_init);
module_exit(plugin_exit);
MODULE_LICENSE("GPL");
使用示例:
服务端:
客户端: