改进方法:
将
dataset = fetch_20newsgroups(categories=categories)
改为:
dataset = _fetch_20newsgroups(categories=categories)
并添加以下导入语句、常量和函数定义:
import shutil
import matplotlib as mpl
from sklearn.datasets.twenty_newsgroups import strip_newsgroup_header, strip_newsgroup_footer, strip_newsgroup_quoting
from sklearn.utils import check_random_state
mpl.use('Agg')
import os
import pickle
import codecs
import math
import networkx as nx
import pickle as pkl
import numpy as np
from itertools import product
from sklearn.datasets import fetch_20newsgroups, load_files
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import kneighbors_graph
from sklearn.manifold import TSNE
from tqdm import tqdm
import tarfile
from matplotlib import pyplot as plt
# Names of the per-split folders joined onto the extraction directory.
TRAIN_FOLDER = '20news-bydate-train'
TEST_FOLDER = '20news-bydate-test'

# Batch size (not used by the two loader functions defined in this section).
batch_size = 32

# Newsgroup names passed as the `categories` filter when loading the dataset.
categories = [
    'comp.graphics',
    'rec.sport.baseball',
    'talk.politics.guns',
]
def download_20newsgroups(target_dir):
    """Extract the local 20 newsgroups archive and load both splits.

    Despite its name, this function performs no network download and writes
    no pickle: it unpacks ``./data/20news-bydate.tar.gz`` into ``target_dir``,
    reads the train and test folders with ``load_files`` and then deletes
    ``target_dir`` again.

    Parameters
    ----------
    target_dir : str
        Scratch directory used for extraction. It is created when missing and
        removed, together with everything inside it, before the function
        returns or propagates an error.

    Returns
    -------
    dict
        Mapping with the keys ``'train'`` and ``'test'``; each value is the
        ``load_files`` result for that split, decoded as latin1.
    """
    train_path = os.path.join(target_dir, TRAIN_FOLDER)
    test_path = os.path.join(target_dir, TEST_FOLDER)
    archive_path = "./data/20news-bydate.tar.gz"

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    try:
        # The context manager closes the archive handle even when extraction
        # fails part-way through.
        # NOTE(review): extractall() trusts the member paths stored in the
        # archive; only use it on an archive from a trusted source.
        with tarfile.open(archive_path, "r:gz") as archive:
            archive.extractall(path=target_dir)
        cache = dict(train=load_files(train_path, encoding='latin1'),
                     test=load_files(test_path, encoding='latin1'))
    finally:
        # The extracted tree is only a staging area; drop it on success and
        # on failure so that no partial extraction is left on disk.
        shutil.rmtree(target_dir)
    return cache
def _fetch_20newsgroups(data_home=None, subset='train', categories=None,
                        shuffle=True, random_state=42,
                        remove=(),
                        download_if_missing=True):
    """Load the 20 newsgroups dataset from the local archive.

    The data is obtained through ``download_20newsgroups``, which works in
    the scratch directory ``./20news_home``.

    Parameters
    ----------
    data_home : optional
        Accepted for signature compatibility; not used by this function.
    subset : {'train', 'test', 'all'}
        Which split to return. ``'all'`` concatenates train followed by test.
    categories : sequence of str or None
        When given, only documents whose category name is listed are kept
        and the targets are renumbered to ``0..len(categories) - 1`` following
        the original label order.
    shuffle : bool
        Whether to apply one random permutation to data, target and
        filenames.
    random_state : int, RandomState instance or None
        Passed to ``check_random_state`` to seed the shuffle.
    remove : tuple of str
        Any of ``'headers'``, ``'footers'``, ``'quotes'``; the matching
        ``strip_newsgroup_*`` helper is applied to every document.
    download_if_missing : bool
        Accepted for signature compatibility; not used by this function.

    Returns
    -------
    The object produced by ``load_files`` for the selected split, with its
    ``data``, ``target``, ``filenames``, ``target_names`` and ``description``
    attributes updated as described above.

    Raises
    ------
    ValueError
        If ``subset`` is not ``'train'``, ``'test'`` or ``'all'``.
    """
    # Validate first so that a typo does not cost a full archive extraction.
    if subset not in ('train', 'test', 'all'):
        raise ValueError(
            "subset can only be 'train', 'test' or 'all', got '%s'" % subset)

    twenty_home = os.path.join("./", "20news_home")
    cache = download_20newsgroups(target_dir=twenty_home)

    if subset == 'all':
        data_lst = []
        target = []
        filenames = []
        for part in ('train', 'test'):
            data = cache[part]
            data_lst.extend(data.data)
            target.extend(data.target)
            filenames.extend(data.filenames)
        # `data` now refers to the last split loaded ('test'); it is reused
        # as the container for the merged content.
        data.data = data_lst
        data.target = np.array(target)
        data.filenames = np.array(filenames)
    else:
        data = cache[subset]

    data.description = 'the 20 newsgroups by date dataset'

    if 'headers' in remove:
        data.data = [strip_newsgroup_header(text) for text in data.data]
    if 'footers' in remove:
        data.data = [strip_newsgroup_footer(text) for text in data.data]
    if 'quotes' in remove:
        data.data = [strip_newsgroup_quoting(text) for text in data.data]

    if categories is not None:
        labels = [(data.target_names.index(cat), cat) for cat in categories]
        # Sort the categories to have the ordering of the labels
        labels.sort()
        labels, categories = zip(*labels)
        # np.isin is the supported replacement for the deprecated np.in1d.
        mask = np.isin(data.target, labels)
        data.filenames = data.filenames[mask]
        data.target = data.target[mask]
        # searchsorted to have continuous labels
        data.target = np.searchsorted(labels, data.target)
        data.target_names = list(categories)
        # Index through an object array so the boolean mask can be applied
        # to the list of documents.
        data_lst = np.array(data.data, dtype=object)
        data_lst = data_lst[mask]
        data.data = data_lst.tolist()

    if shuffle:
        random_state = check_random_state(random_state)
        indices = np.arange(data.target.shape[0])
        random_state.shuffle(indices)
        # The same permutation is applied to all three parallel sequences.
        data.filenames = data.filenames[indices]
        data.target = data.target[indices]
        data_lst = np.array(data.data, dtype=object)
        data_lst = data_lst[indices]
        data.data = data_lst.tolist()

    return data