#lang scheme
;; make-node : point nat node node -> (symbol any ... -> any)
;; Represents one k-d tree node as a message-passing closure.
;; Sending 'point, 'axis, 'left-child or 'right-child returns the field
;; captured at construction time. Any extra arguments are accepted and
;; ignored; an unrecognised message produces (void).
(define (make-node point axis left-child right-child)
  (define (dispatch message . _ignored)
    (case message
      [(point)       point]
      [(axis)        axis]
      [(left-child)  left-child]
      [(right-child) right-child]))
  dispatch)
;; square-distance : (listof number) (listof number) -> number
;; Sum of squared coordinate differences between two points, i.e. the
;; squared Euclidean distance. Recursion is driven by lst1: once lst1 is
;; exhausted the result is 0, so extra trailing coordinates in lst2 are
;; not counted.
(define (square-distance lst1 lst2)
  (if (null? lst1)
      0
      (let ([delta (- (car lst1) (car lst2))])
        (+ (expt delta 2)
           (square-distance (cdr lst1) (cdr lst2))))))
;; make-kd-tree : -> (symbol any ... -> any)
;; Creates a k-d tree object in message-passing style. The dispatcher
;; understands two messages:
;;   'build-tree point-list  -- (re)builds the tree from a list of points;
;;                              the dimension is taken from the first point.
;;   'search-nearest query   -- returns a mutable hash with keys "point"
;;                              (the nearest stored point, or 'nil if the
;;                              tree is empty) and "distance" (its SQUARED
;;                              Euclidean distance to query, +inf.0 if empty).
;; Any other message raises an error.
;; Empty subtrees, and the empty tree itself, are the symbol 'nil.
(define (make-kd-tree)
  ;; Start as 'nil rather than '(): search treats 'nil as "no node", so a
  ;; search before any build-tree returns the empty result instead of
  ;; trying to apply '() as a node procedure.
  (let ([root 'nil])
    (define (build-tree point-list)
      ;; build : vector nat -> node or 'nil
      ;; Sort the points on the axis selected by depth, make the median the
      ;; node, then recurse on the points before and after the median.
      (define (build points depth)
        (define n (vector-length points))
        (cond
          [(= n 0) 'nil]
          [else
           (let* ([median-index (quotient n 2)]
                  [dimension (length (vector-ref points 0))]
                  [axis (modulo depth dimension)]
                  [sorted (list->vector
                           (sort (vector->list points)
                                 #:key (lambda (item) (list-ref item axis))
                                 <))])
             (make-node (vector-ref sorted median-index)
                        axis
                        (build (vector-copy sorted 0 median-index)
                               (+ depth 1))
                        (build (vector-copy sorted (+ median-index 1) n)
                               (+ depth 1))))]))
      (set! root (build (list->vector point-list) 0)))
    (define (search-nearest query-point)
      (let ([best-point (make-hash)])
        (dict-set*! best-point "point" 'nil "distance" +inf.0)
        ;; search : node-or-'nil -> void
        ;; Descend into the child on the query's side of the splitting
        ;; plane first. Visit the other child only if the plane is closer
        ;; (squared) than the best distance found so far.
        (define (search this-node)
          (unless (eq? this-node 'nil)
            (let* ([point (this-node 'point)]
                   [axis (this-node 'axis)]
                   [distance (square-distance point query-point)]
                   [axis-diff (- (list-ref query-point axis)
                                 (list-ref point axis))])
              (when (< distance (dict-ref best-point "distance"))
                (dict-set! best-point "point" point)
                (dict-set! best-point "distance" distance))
              (let-values ([(near far)
                            (if (<= axis-diff 0)
                                (values (this-node 'left-child)
                                        (this-node 'right-child))
                                (values (this-node 'right-child)
                                        (this-node 'left-child)))])
                (search near)
                (when (< (* axis-diff axis-diff)
                         (dict-ref best-point "distance"))
                  (search far))))))
        (search root)
        best-point))
    (define (dispatch msg . args)
      (case msg
        [(build-tree) (build-tree (car args))]
        [(search-nearest) (search-nearest (car args))]
        [else (error 'kd-tree "unknown message: ~e" msg)]))
    dispatch))
;; Demo: build a 2-d tree from six sample points and query the point
;; nearest to (2.1, 3.1). "distance" in the result hash is squared, so the
;; final expression takes its square root to print the Euclidean distance.
(define tree (make-kd-tree))
(let ([sample-points '((2 3) (5 4) (9 6) (4 7) (8 1) (7 2))])
  (tree 'build-tree sample-points))
(define res (tree 'search-nearest (list 2.1 3.1)))
(sqrt (hash-ref res "distance"))
;; make-node : point nat node node -> (symbol any ... -> any)
;; Represents one k-d tree node as a message-passing closure.
;; Sending 'point, 'axis, 'left-child or 'right-child returns the field
;; captured at construction time. Any extra arguments are accepted and
;; ignored; an unrecognised message produces (void).
(define (make-node point axis left-child right-child)
  (define (dispatch message . _ignored)
    (case message
      [(point)       point]
      [(axis)        axis]
      [(left-child)  left-child]
      [(right-child) right-child]))
  dispatch)
;; square-distance : (listof number) (listof number) -> number
;; Sum of squared coordinate differences between two points, i.e. the
;; squared Euclidean distance. Recursion is driven by lst1: once lst1 is
;; exhausted the result is 0, so extra trailing coordinates in lst2 are
;; not counted.
(define (square-distance lst1 lst2)
  (if (null? lst1)
      0
      (let ([delta (- (car lst1) (car lst2))])
        (+ (expt delta 2)
           (square-distance (cdr lst1) (cdr lst2))))))
;; make-kd-tree : -> (symbol any ... -> any)
;; Creates a k-d tree object in message-passing style. The dispatcher
;; understands two messages:
;;   'build-tree point-list  -- (re)builds the tree from a list of points;
;;                              the dimension is taken from the first point.
;;   'search-nearest query   -- returns a mutable hash with keys "point"
;;                              (the nearest stored point, or 'nil if the
;;                              tree is empty) and "distance" (its SQUARED
;;                              Euclidean distance to query, +inf.0 if empty).
;; Any other message raises an error.
;; Empty subtrees, and the empty tree itself, are the symbol 'nil.
(define (make-kd-tree)
  ;; Start as 'nil rather than '(): search treats 'nil as "no node", so a
  ;; search before any build-tree returns the empty result instead of
  ;; trying to apply '() as a node procedure.
  (let ([root 'nil])
    (define (build-tree point-list)
      ;; build : vector nat -> node or 'nil
      ;; Sort the points on the axis selected by depth, make the median the
      ;; node, then recurse on the points before and after the median.
      (define (build points depth)
        (define n (vector-length points))
        (cond
          [(= n 0) 'nil]
          [else
           (let* ([median-index (quotient n 2)]
                  [dimension (length (vector-ref points 0))]
                  [axis (modulo depth dimension)]
                  [sorted (list->vector
                           (sort (vector->list points)
                                 #:key (lambda (item) (list-ref item axis))
                                 <))])
             (make-node (vector-ref sorted median-index)
                        axis
                        (build (vector-copy sorted 0 median-index)
                               (+ depth 1))
                        (build (vector-copy sorted (+ median-index 1) n)
                               (+ depth 1))))]))
      (set! root (build (list->vector point-list) 0)))
    (define (search-nearest query-point)
      (let ([best-point (make-hash)])
        (dict-set*! best-point "point" 'nil "distance" +inf.0)
        ;; search : node-or-'nil -> void
        ;; Descend into the child on the query's side of the splitting
        ;; plane first. Visit the other child only if the plane is closer
        ;; (squared) than the best distance found so far.
        (define (search this-node)
          (unless (eq? this-node 'nil)
            (let* ([point (this-node 'point)]
                   [axis (this-node 'axis)]
                   [distance (square-distance point query-point)]
                   [axis-diff (- (list-ref query-point axis)
                                 (list-ref point axis))])
              (when (< distance (dict-ref best-point "distance"))
                (dict-set! best-point "point" point)
                (dict-set! best-point "distance" distance))
              (let-values ([(near far)
                            (if (<= axis-diff 0)
                                (values (this-node 'left-child)
                                        (this-node 'right-child))
                                (values (this-node 'right-child)
                                        (this-node 'left-child)))])
                (search near)
                (when (< (* axis-diff axis-diff)
                         (dict-ref best-point "distance"))
                  (search far))))))
        (search root)
        best-point))
    (define (dispatch msg . args)
      (case msg
        [(build-tree) (build-tree (car args))]
        [(search-nearest) (search-nearest (car args))]
        [else (error 'kd-tree "unknown message: ~e" msg)]))
    dispatch))
;; Demo: build a 2-d tree from six sample points and query the point
;; nearest to (2.1, 3.1). "distance" in the result hash is squared, so the
;; final expression takes its square root to print the Euclidean distance.
(define tree (make-kd-tree))
(let ([sample-points '((2 3) (5 4) (9 6) (4 7) (8 1) (7 2))])
  (tree 'build-tree sample-points))
(define res (tree 'search-nearest (list 2.1 3.1)))
(sqrt (hash-ref res "distance"))