前面讲的符号表都是一个键对应一个值或多个键对应一个值,生活中还有很多应用需要一个键对应多个值,比如一个电影对应多个演员,这种结构被称为索引
下面的代码利用符号表实现索引和反向索引,可以把一个电影和多个演员关联,并且可以反过来把一个演员和多部电影关联起来
/*************************************************************************
* Compilation: javac LookupIndex.java
* Execution: java LookupIndex movies.txt "/"
* Dependencies: ST.java Queue.java In.java StdIn.java StdOut.java
* Data files: http://algs4.cs.princeton.edu/35applications/aminoI.txt
* http://algs4.cs.princeton.edu/35applications/movies.txt
*
* % java LookupIndex aminoI.txt ","
* Serine
* TCT
* TCA
* TCG
* AGT
* AGC
* TCG
* Serine
*
* % java LookupIndex movies.txt "/"
* Bacon, Kevin
* Animal House (1978)
* Apollo 13 (1995)
* Beauty Shop (2005)
* Diner (1982)
* Few Good Men, A (1992)
* Flatliners (1990)
* Footloose (1984)
* Friday the 13th (1980)
* ...
* Tin Men (1987)
* DeBoy, David
* Blumenfeld, Alan
* ...
*
*************************************************************************/
/**
 * Builds an index and an inverted index from a delimited data file, then
 * answers lookup queries read from standard input.
 *
 * <p>Each input line is split on {@code separator}: the first field is the key
 * and every remaining field is a value associated with that key. The forward
 * index maps each key to its values; the inverted index maps each value back to
 * every key it appeared with (e.g. movie -> performers and performer -> movies).
 *
 * <p>Note: {@link String#split(String)} interprets {@code separator} as a
 * regular expression, so metacharacters such as {@code |} or {@code .} must be
 * escaped on the command line.
 */
public class LookupIndex {

    /**
     * Reads and indexes the data file, then prints the matches for each query.
     *
     * <p>For every query line read from standard input, prints each value the
     * query maps to (if it is a key) followed by each key it maps to (if it is a
     * value), one per line, each prefixed with a space.
     *
     * @param args {@code args[0]} is the data file name; {@code args[1]} is the
     *             field separator (a regular expression)
     * @throws IllegalArgumentException if fewer than two arguments are given
     */
    public static void main(String[] args) {
        if (args.length < 2) {
            throw new IllegalArgumentException("Usage: java LookupIndex filename separator");
        }
        String filename = args[0];
        String separator = args[1];

        ST<String, Queue<String>> st = new ST<String, Queue<String>>();  // key   -> values
        ST<String, Queue<String>> ts = new ST<String, Queue<String>>();  // value -> keys

        In in = new In(filename);
        try {
            while (in.hasNextLine()) {
                String line = in.readLine();
                String[] fields = line.split(separator);
                String key = fields[0];
                for (int i = 1; i < fields.length; i++) {
                    String val = fields[i];
                    // Queues are created lazily, so a line with no values adds nothing.
                    if (!st.contains(key)) st.put(key, new Queue<String>());
                    if (!ts.contains(val)) ts.put(val, new Queue<String>());
                    st.get(key).enqueue(val);
                    ts.get(val).enqueue(key);
                }
            }
        } finally {
            // The query loop below only consults the in-memory tables, so release
            // the file now instead of holding it open for the whole session.
            in.close();
        }
        StdOut.println("Done indexing");

        // Read queries from standard input, one per line. A query may be both a
        // key and a value, in which case both result lists are printed.
        while (!StdIn.isEmpty()) {
            String query = StdIn.readLine();
            if (st.contains(query))
                for (String vals : st.get(query))
                    StdOut.println(" " + vals);
            if (ts.contains(query))
                for (String keys : ts.get(query))
                    StdOut.println(" " + keys);
        }
    }
}
package chapter3_5;
/*************************************************************************
* Compilation: javac FileIndex.java
* Execution: java FileIndex file1.txt file2.txt file3.txt ...
* Dependencies: ST.java SET.java In.java StdIn.java StdOut.java
* Data files: http://algs4.cs.princeton.edu/ex1.txt
* http://algs4.cs.princeton.edu/ex2.txt
* http://algs4.cs.princeton.edu/ex3.txt
* http://algs4.cs.princeton.edu/ex4.txt
*
* % java FileIndex ex*.txt
* age
* ex3.txt
* ex4.txt
* best
* ex1.txt
* was
* ex1.txt
* ex2.txt
* ex3.txt
* ex4.txt
*
* % java FileIndex *.txt
*
* % java FileIndex *.java
*
*************************************************************************/
import java.io.File;
//此用例为文件创建索引,把文件里的每个词和包含它的文件关联起来,可以通过词快速查找到所有包含它的文件
/**
 * Builds an inverted index that maps each whitespace-delimited word to the set
 * of files containing it, then answers word queries read from standard input.
 */
public class FileIndex {

    /**
     * Indexes every file named on the command line, then prints the name of
     * each indexed file that contains each query word.
     *
     * @param args the names of the files to index
     */
    public static void main(String[] args) {
        // key = word, value = set of files containing that word
        ST<String, SET<File>> st = new ST<String, SET<File>>();

        // create inverted index of all files
        StdOut.println("Indexing files");
        for (String filename : args) {
            StdOut.println(" " + filename);
            File file = new File(filename);
            In in = new In(file);
            try {
                while (!in.isEmpty()) {
                    String word = in.readString();
                    if (!st.contains(word)) st.put(word, new SET<File>());
                    st.get(word).add(file);
                }
            } finally {
                // One handle is opened per argument; close it before moving on so
                // indexing many files does not accumulate open file handles.
                in.close();
            }
        }

        // Read queries from standard input as whitespace-separated words (not
        // whole lines) and print every indexed file that contains each word.
        while (!StdIn.isEmpty()) {
            String query = StdIn.readString();
            if (st.contains(query)) {
                for (File file : st.get(query)) {
                    StdOut.println(" " + file.getName());
                }
            }
        }
    }
}
//向量与向量相乘(点积):如果两个向量的维度都是N,时间复杂度是O(N);
//矩阵与向量相乘:如果矩阵是N行N列、向量的维度也是N,时间复杂度是O(N^2)。
//实际应用中N往往非常大,可达几百亿、几千亿,例如Google的PageRank算法中N等于互联网上网页的总数,用常规算法计算所需的时间是无法接受的。
//好在实际用例中的矩阵往往是稀疏的,即矩阵中绝大多数项都是0;对应到Google的应用,就是每个网页包含的指向其他网页的链接其实很少。
//所以不必把向量的每一项都存储在数组中,而是改用符号表只存储不为0的项,把这些项的数组下标和它的值关联起来。
//由于0乘任何数都等于0、0加任何数都等于该数,计算向量点积的时间只与向量中非0项的个数成正比,矩阵与向量相乘的时间大致与N加上非0项的个数成正比。代码如下:
/**
 * A d-dimensional vector of doubles that stores only its nonzero entries, as
 * index-value pairs in a symbol table. Operations such as {@link #dot} and
 * {@link #plus} therefore iterate over the stored (nonzero) entries instead of
 * over all d coordinates.
 */
public class SparseVector {
    private final int d;                   // dimension
    private final ST<Integer, Double> st;  // nonzero entries only: index -> value

    /**
     * Initializes a d-dimensional zero vector.
     *
     * @param d the dimension of the vector
     */
    public SparseVector(int d) {
        this.d = d;
        this.st = new ST<Integer, Double>();
    }

    /**
     * Sets the ith coordinate of this vector to the specified value.
     * Setting a coordinate to 0.0 removes it from the symbol table, so only
     * nonzero entries are ever stored.
     *
     * @param i the index
     * @param value the new value
     * @throws IllegalArgumentException unless i is between 0 and d-1
     */
    public void put(int i, double value) {
        validateIndex(i);
        if (value == 0.0) st.delete(i);
        else st.put(i, value);
    }

    /**
     * Returns the ith coordinate of this vector.
     *
     * @param i the index
     * @return the value of the ith coordinate of this vector (0.0 if not stored)
     * @throws IllegalArgumentException unless i is between 0 and d-1
     */
    public double get(int i) {
        validateIndex(i);
        return st.contains(i) ? st.get(i) : 0.0;
    }

    /**
     * Returns the number of nonzero entries in this vector.
     *
     * @return the number of nonzero entries in this vector
     */
    public int nnz() {
        return st.size();
    }

    /**
     * Returns the dimension of this vector.
     *
     * @return the dimension of this vector
     * @deprecated Replaced by {@link #dimension()}.
     */
    @Deprecated
    public int size() {
        return dimension();
    }

    /**
     * Returns the dimension of this vector.
     *
     * @return the dimension of this vector
     */
    public int dimension() {
        return d;
    }

    /**
     * Returns the inner product of this vector with the specified vector.
     *
     * @param that the other vector
     * @return the dot product between this vector and that vector
     * @throws IllegalArgumentException if the lengths of the two vectors are not equal
     */
    public double dot(SparseVector that) {
        if (this.d != that.d) throw new IllegalArgumentException("Vector lengths disagree");
        // Iterate over the vector with the fewest nonzeros and probe the other;
        // indices stored in only one vector contribute 0 and are skipped.
        SparseVector smaller = (this.st.size() <= that.st.size()) ? this : that;
        SparseVector larger = (smaller == this) ? that : this;
        double sum = 0.0;
        for (int i : smaller.st.keys())
            if (larger.st.contains(i)) sum += smaller.get(i) * larger.get(i);
        return sum;
    }

    /**
     * Returns the inner product of this vector with the specified array.
     *
     * @param that the array
     * @return the dot product between this vector and that array
     * @throws IllegalArgumentException if the dimensions of the vector and the array are not equal
     */
    public double dot(double[] that) {
        if (that.length != d) throw new IllegalArgumentException("Dimensions disagree");
        double sum = 0.0;
        for (int i : st.keys())
            sum += that[i] * this.get(i);
        return sum;
    }

    /**
     * Returns the magnitude of this vector.
     * This is also known as the L2 norm or the Euclidean norm.
     *
     * @return the magnitude of this vector
     */
    public double magnitude() {
        return Math.sqrt(this.dot(this));
    }

    /**
     * Returns the Euclidean norm of this vector.
     *
     * @return the Euclidean norm of this vector
     * @deprecated Replaced by {@link #magnitude()}.
     */
    @Deprecated
    public double norm() {
        return magnitude();
    }

    /**
     * Returns the scalar-vector product of this vector with the specified scalar.
     * Products that equal 0.0 are not stored in the result.
     *
     * @param alpha the scalar
     * @return the scalar-vector product of this vector with the specified scalar
     */
    public SparseVector scale(double alpha) {
        SparseVector c = new SparseVector(d);
        for (int i : this.st.keys()) c.put(i, alpha * this.get(i));
        return c;
    }

    /**
     * Returns the sum of this vector and the specified vector.
     * Coordinates that cancel to exactly 0.0 are not stored in the result.
     *
     * @param that the vector to add to this vector
     * @return the sum of this vector and that vector
     * @throws IllegalArgumentException if the dimensions of the two vectors are not equal
     */
    public SparseVector plus(SparseVector that) {
        if (this.d != that.d) throw new IllegalArgumentException("Vector lengths disagree");
        SparseVector c = new SparseVector(d);
        for (int i : this.st.keys()) c.put(i, this.get(i));              // c = this
        for (int i : that.st.keys()) c.put(i, that.get(i) + c.get(i));   // c = c + that
        return c;
    }

    /**
     * Returns a string representation of this vector.
     *
     * @return one {@code (index, value)} pair per stored (nonzero) entry, in the
     *         order returned by the symbol table's {@code keys()}, each pair
     *         followed by a single space; the empty string for the zero vector
     */
    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();
        for (int i : st.keys()) {
            s.append('(').append(i).append(", ").append(st.get(i)).append(") ");
        }
        return s.toString();
    }

    // Rejects indices outside [0, d-1]; shared by put() and get().
    private void validateIndex(int i) {
        if (i < 0 || i >= d) throw new IllegalArgumentException("Illegal index");
    }

    /**
     * Unit tests the {@code SparseVector} data type.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args) {
        SparseVector a = new SparseVector(10);
        SparseVector b = new SparseVector(10);
        a.put(3, 0.50);
        a.put(9, 0.75);
        a.put(6, 0.11);
        a.put(6, 0.00);   // setting to zero removes the entry again
        b.put(3, 0.60);
        b.put(4, 0.90);
        StdOut.println("a = " + a);
        StdOut.println("b = " + b);
        StdOut.println("a dot b = " + a.dot(b));
        StdOut.println("a + b = " + a.plus(b));
    }
}
/******************************************************************************
* Compilation: javac SparseMatrix.java
* Execution: java SparseMatrix
* Dependencies: SparseVector.java StdOut.java
*
* A sparse, square matrix, implemented using an array of sparse
* vectors, one per row.
*
* For matrix-matrix product, we might also want to store the
* column representation.
*
******************************************************************************/
/**
 * An n-by-n sparse matrix stored as an array of {@link SparseVector} rows.
 * Each row keeps only its nonzero entries, and the matrix-vector product is
 * computed as one sparse dot product per row.
 */
public class SparseMatrix {
    private final int n;                // n-by-n matrix
    private final SparseVector[] rows;  // rows[i] holds the nonzero entries of row i

    /**
     * Initializes an n-by-n matrix of all 0s.
     *
     * @param n the number of rows (and columns)
     */
    public SparseMatrix(int n) {
        this.n = n;
        rows = new SparseVector[n];
        for (int i = 0; i < n; i++)
            rows[i] = new SparseVector(n);
    }

    /**
     * Sets A[i][j] = value; storing 0.0 removes the entry.
     *
     * @throws IllegalArgumentException unless 0 <= i, j < n
     */
    public void put(int i, int j, double value) {
        validateIndices(i, j);
        rows[i].put(j, value);
    }

    /**
     * Returns A[i][j] (0.0 for entries that are not stored).
     *
     * @throws IllegalArgumentException unless 0 <= i, j < n
     */
    public double get(int i, int j) {
        validateIndices(i, j);
        return rows[i].get(j);
    }

    /**
     * Returns the number of nonzero entries, by summing over all rows.
     * (Not the most efficient implementation: it visits every row.)
     */
    public int nnz() {
        int sum = 0;
        for (int i = 0; i < n; i++)
            sum += rows[i].nnz();
        return sum;
    }

    /**
     * Returns the matrix-vector product b = Ax.
     *
     * @throws IllegalArgumentException if the dimension of x is not n
     */
    public SparseVector times(SparseVector x) {
        if (n != x.dimension()) throw new IllegalArgumentException("Dimensions disagree");
        SparseVector b = new SparseVector(n);
        for (int i = 0; i < n; i++)
            b.put(i, rows[i].dot(x));   // b[i] = (row i) . x
        return b;
    }

    /**
     * Returns this + that.
     *
     * @throws IllegalArgumentException if the two matrices have different sizes
     */
    public SparseMatrix plus(SparseMatrix that) {
        if (this.n != that.n) throw new IllegalArgumentException("Dimensions disagree");
        SparseMatrix result = new SparseMatrix(n);
        for (int i = 0; i < n; i++)
            result.rows[i] = this.rows[i].plus(that.rows[i]);
        return result;
    }

    /**
     * Returns a string representation: a header line with n and the nonzero
     * count, followed by one line per row.
     */
    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();
        s.append("n = ").append(n).append(", nonzeros = ").append(nnz()).append("\n");
        for (int i = 0; i < n; i++) {
            s.append(i).append(": ").append(rows[i]).append("\n");
        }
        return s.toString();
    }

    // Rejects a row or column index outside [0, n-1].
    private void validateIndices(int i, int j) {
        if (i < 0 || i >= n) throw new IllegalArgumentException("Illegal index");
        if (j < 0 || j >= n) throw new IllegalArgumentException("Illegal index");
    }

    // test client
    public static void main(String[] args) {
        SparseMatrix A = new SparseMatrix(5);
        SparseVector x = new SparseVector(5);
        A.put(0, 0, 1.0);
        A.put(1, 1, 1.0);
        A.put(2, 2, 1.0);
        A.put(3, 3, 1.0);
        A.put(4, 4, 1.0);
        A.put(2, 4, 0.3);
        x.put(0, 0.75);
        x.put(2, 0.11);
        StdOut.println("x : " + x);
        StdOut.println("A : " + A);
        StdOut.println("Ax : " + A.times(x));
        StdOut.println("A + A : " + A.plus(A));
    }
}