# scikit-learn 文本分类

import numpy as np
import pandas as pd

# Tab-separated SMS corpus; expected to contain at least 'label' and 'message'.
DATA_PATH = './smsspamcollection.tsv'


def load_data(path=DATA_PATH):
    """Load the SMS spam corpus and drop rows that contain missing values.

    Args:
        path: Location of a tab-separated file with at least the
            ``label`` and ``message`` columns.

    Returns:
        A DataFrame with every row containing a missing value removed.
    """
    df = pd.read_csv(path, sep='\t')
    # Clean *before* X / y are selected: a column Series taken from the frame
    # earlier is not shrunk by a later dropna on the frame, so NaN messages
    # would still reach TfidfVectorizer (which rejects NaN documents).
    return df.dropna()


def split_data(df, test_size=0.3, random_state=42):
    """Split messages and labels into train and test sets.

    Args:
        df: Cleaned DataFrame with ``message`` and ``label`` columns.
        test_size: Fraction of rows held out for testing.
        random_state: Seed for the shuffle; fixed by default so the split
            and the reported metrics are reproducible. Pass ``None`` for a
            different random split each run.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as returned by
        ``train_test_split``.
    """
    # Imported lazily so load_data() can be used without scikit-learn.
    from sklearn.model_selection import train_test_split

    X = df['message']
    y = df['label']
    return train_test_split(X, y, test_size=test_size,
                            random_state=random_state)


def build_pipeline(clf=None):
    """Return a TF-IDF vectorizer followed by a classifier.

    Args:
        clf: Any scikit-learn classifier instance (for example
            ``MultinomialNB()``); defaults to ``LinearSVC()``.

    Returns:
        An unfitted ``Pipeline`` with steps ``'tfidf'`` and ``'clf'``.
    """
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.pipeline import Pipeline
    from sklearn.svm import LinearSVC

    if clf is None:
        clf = LinearSVC()
    return Pipeline([
        ('tfidf', TfidfVectorizer()),
        ('clf', clf),
    ])


if __name__ == '__main__':
    from sklearn import metrics

    df = load_data()
    # Bare expressions are not displayed outside a notebook, so print them.
    print(df.head())
    print(df['label'].value_counts())

    X_train, X_test, y_train, y_test = split_data(df)

    # Feed the training data through the pipeline
    text_clf = build_pipeline()
    text_clf.fit(X_train, y_train)

    # Prediction & Metrics
    predictions = text_clf.predict(X_test)
    print(metrics.confusion_matrix(y_test, predictions))
    print(metrics.classification_report(y_test, predictions))

# 其他可选模型，如：

from sklearn.naive_bayes import MultinomialNB

 

 

Leave a Reply

Your email address will not be published. Required fields are marked *