#ifndef _GNU_SOURCE
#define _GNU_SOURCE /* expose syscall(), SYS_gettid and POSIX signal APIs */
#endif
#include <pthread.h>
#include <setjmp.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/syscall.h>
#include <unistd.h>
/* Kernel thread ids (from syscall(SYS_gettid)) of the main thread and the
 * worker thread; the SIGSEGV handler compares the faulting thread's id against
 * these to decide where to resume. */
int main_id, thread_id;
/* Recovery points for the handler's longjmp(). One buffer per thread, because
 * a jmp_buf captures the context of the thread that called setjmp(). */
jmp_buf jmp_buf_main, jmp_buf_thread;
/* Append NUL-terminated s to buf (capacity cap) starting at offset len and
 * return the new length. Requires len < cap; output is truncated rather than
 * overflowing. Async-signal-safe (no stdio, no allocation). */
static size_t handler_append_str(char *buf, size_t cap, size_t len, const char *s)
{
    while (*s != '\0' && len + 1 < cap)
        buf[len++] = *s++;
    buf[len] = '\0';
    return len;
}

/* Append the decimal form of v to buf; same contract as handler_append_str. */
static size_t handler_append_int(char *buf, size_t cap, size_t len, int v)
{
    char digits[3 * sizeof(int) + 2]; /* any int's digits, sign and NUL */
    char *q = digits + sizeof digits;
    /* Negate in unsigned arithmetic so INT_MIN does not overflow. */
    unsigned int u = v < 0 ? 0u - (unsigned int)v : (unsigned int)v;

    *--q = '\0';
    do {
        *--q = (char)('0' + u % 10u);
        u /= 10u;
    } while (u != 0u);
    if (v < 0)
        *--q = '-';
    return handler_append_str(buf, cap, len, q);
}

/*
 * SIGSEGV handler shared by both threads. Reports the signal number and the
 * faulting thread's kernel id on stderr, then longjmps to that thread's
 * recovery point (jmp_buf_main or jmp_buf_thread). A fault in any other
 * thread prints "unknown error" and terminates the process with status 255.
 *
 * fprintf() and exit() are not async-signal-safe, so the message is
 * formatted by hand and emitted with a single write(), and _exit() is used.
 */
void exception_handler_main(int sig)
{
    static const char unknown[] = "unknown error\n";
    int id = (int)syscall(SYS_gettid);
    char msg[64];
    size_t len = 0;
    sigset_t unblock;
    ssize_t ignored;

    len = handler_append_str(msg, sizeof msg, len, "Got signal ");
    len = handler_append_int(msg, sizeof msg, len, sig);
    len = handler_append_str(msg, sizeof msg, len, ", id ");
    len = handler_append_int(msg, sizeof msg, len, id);
    len = handler_append_str(msg, sizeof msg, len, "\n");
    /* Best effort: a failed diagnostic write must not prevent recovery. */
    ignored = write(STDERR_FILENO, msg, len);
    (void)ignored;

    /* A handler installed without SA_NODEFER runs with sig blocked, and plain
     * longjmp() is not required to restore the pre-signal mask. Unblock sig
     * for this thread only (pthread_sigmask is thread-specific and
     * async-signal-safe, unlike sigrelse) so the next fault is handled
     * instead of being delivered while blocked. */
    sigemptyset(&unblock);
    sigaddset(&unblock, sig);
    pthread_sigmask(SIG_UNBLOCK, &unblock, NULL);

    if (id == main_id)
        longjmp(jmp_buf_main, 1);
    if (id == thread_id)
        longjmp(jmp_buf_thread, 1);

    /* No recovery point for this thread: report and terminate. */
    ignored = write(STDERR_FILENO, unknown, sizeof unknown - 1);
    (void)ignored;
    _exit(-1);
}
void *thread_function(void *arg);

/*
 * Demo driver: installs the SIGSEGV handler, starts one worker thread, then
 * deliberately faults three times in the main thread. Each time it recovers
 * through setjmp/longjmp. Finally it joins the worker and prints its result.
 * Exits with EXIT_FAILURE if the handler or the thread cannot be set up.
 */
int main(void)
{
    /* volatile stops the compiler from treating the known-NULL store below as
     * UB it may delete or replace with a trap instruction (SIGILL, which is
     * not handled) instead of a real SIGSEGV-raising store. */
    int *volatile p = NULL;
    pthread_t a_thread;
    void *thread_result;
    int res;

    main_id = (int)syscall(SYS_gettid);
    fprintf(stderr, "main id: %d\n", main_id);

    /* Installed before the worker starts so faults in both threads reach it. */
    if (signal(SIGSEGV, exception_handler_main) == SIG_ERR) {
        perror("signal");
        exit(EXIT_FAILURE);
    }

    /* pthread_* functions return an error number and do not set errno, so
     * report the returned code instead of using perror(). */
    res = pthread_create(&a_thread, NULL, thread_function, NULL);
    if (res != 0) {
        fprintf(stderr, "Thread creation failed: %s\n", strerror(res));
        exit(EXIT_FAILURE);
    }

    for (int i = 0; i < 3; i++) {
        if (setjmp(jmp_buf_main) == 0) {
            *p = 10; /* intentional fault; the handler longjmps back here */
        } else {
            fprintf(stderr, "return to main\n");
        }
        sleep(1);
    }

    printf("Waiting for thread to finish...\n");
    res = pthread_join(a_thread, &thread_result);
    if (res != 0) {
        fprintf(stderr, "Thread join failed: %s\n", strerror(res));
        exit(EXIT_FAILURE);
    }
    /* Passing NULL for %s is undefined behaviour, and the worker returns NULL. */
    printf("Thread joined, it returned %s\n",
           thread_result != NULL ? (char *)thread_result : "(null)");
    return 0;
}
/*
 * Worker thread body. Records its kernel thread id for the SIGSEGV handler,
 * then deliberately faults three times. Each time it recovers through
 * setjmp/longjmp via jmp_buf_thread.
 *
 * arg: unused.
 * Returns NULL. The loop is bounded so the thread actually terminates and
 * pthread_join() in main() can return.
 */
void *thread_function(void *arg)
{
    /* volatile stops the compiler from treating the known-NULL store below as
     * UB it may delete or replace with a trap instruction instead of a real
     * faulting store. */
    int *volatile p = NULL;

    (void)arg;
    thread_id = (int)syscall(SYS_gettid);
    fprintf(stderr, "thread id: %d\n", thread_id);

    for (int i = 0; i < 3; i++) {
        fprintf(stderr, "thread: %d\n", i);
        if (setjmp(jmp_buf_thread) == 0) {
            *p = 10; /* intentional fault; the handler longjmps back here */
        } else {
            fprintf(stderr, "return to thread\n");
        }
        sleep(1);
    }
    return NULL;
}