Python 教程

Python 教程 Python 简介 Python 历史 Python 下载安装 Python 入门 Python 语法 Python 注释 Python 变量 Python 数据类型 Python 数值类型 Python 类型转换 Python 字符串 Python 布尔值 Python 运算符 Python 列表 Python 元组 Python 集合 Python 字典 Python If...Else Python While 循环 Python For 循环 Python 函数 Python Lambda Python 数组 Python 类和对象 Python 继承 Python 迭代 Python 作用域 Python 模块 Python 日期时间 Python 数学运算 Python JSON Python 正则表达式 Python PIP Python Try...Except Python 用户输入 Python 字符串格式化

Python 文件处理

Python 文件处理 Python 打开文件 Python 创建/写入文件 Python 删除文件

Python NumPy

NumPy 简介 NumPy 入门 NumPy 创建数组 NumPy 数组索引 NumPy 数组裁切 NumPy 数据类型 NumPy 副本 vs 视图 NumPy 数组形状 NumPy 数组重塑 NumPy 数组迭代 NumPy 数组连接 NumPy 数组拆分 NumPy 数组搜索 NumPy 数组排序 NumPy 数组过滤 NumPy 随机数 NumPy ufunc 通用函数

Python SciPy

SciPy 简介 SciPy 入门 SciPy 常量 SciPy 优化器 SciPy 稀疏数据 SciPy 图表 SciPy 空间数据 SciPy Matlab 数组 SciPy 插值 SciPy 统计显着性检验

Python 机器学习

Machine 机器学习入门 Machine 平均中位数模式 Machine 标准差 Machine 百分位数 Machine 数据分布 Machine 正态数据分布 Machine 散点图 Machine 线性回归 Machine 多项式回归 Machine 多元回归 Machine 缩放 Machine 训练/测试 Machine 决策树

Python MySQL

MySQL 入门 MySQL Create Database MySQL Create Table MySQL Insert MySQL Select MySQL Where MySQL Order By MySQL Delete MySQL Drop Table MySQL Update MySQL Limit MySQL Join

Python MongoDB

MongoDB 入门 MongoDB 创建数据库 MongoDB 创建集合 MongoDB 插入 MongoDB 查找 MongoDB 查询 MongoDB 排序 MongoDB 删除 MongoDB 删除集合 MongoDB 更新 MongoDB 限制

Python 参考手册

Python 参考手册 Python 内置函数 Python 字符串方法 Python 列表/数组方法 Python 字典方法 Python 元组方法 Python 集合方法 Python 文件方法 Python 关键字 Python 内置异常 Python 词汇表

Python 模块参考

Python 随机模块 Python 请求模块 Python 统计模块 Python 数学模块 Python cMath模块

Python 如何使用

Python 删除列表重复项 Python 反转字符串 Python 添加两个数字

Python 高级教程

Python 常用指引 将Python2代码迁移到Python3 将扩展模块移植到 Python3 Curses 编程 描述器使用指南 函数式编程指引 日志常用指引 日志操作手册 正则表达式使用指南 套接字编程指南 排序指南 Unicode 指南 如何利用urllib包获取网络资源 Argparse 教程 ipaddress 模块介绍 Argument Clinic 的用法 使用DTrace和SystemTap检测CPython 对象注解属性的最佳实践

Python 实例

Python 实例 Python 编译器 Python 练习 Python 测验 NumPy 测验 SciPy 测验


机器学习 - 线性回归

回归

当您尝试找到变量之间的关系时,会用到术语"回归"(regression)。

在机器学习和统计建模中,这种关系用于预测未来事件的结果。


线性回归

线性回归使用数据点之间的关系在所有数据点之间画一条直线。

这条线可以用来预测未来的值。

在机器学习中,预测未来非常重要。


工作原理

Python 提供了一些方法来查找数据点之间的关系并绘制线性回归线。我们将向您展示如何使用这些方法而不是通过数学公式。

在下面的示例中,x 轴表示车龄,y 轴表示速度。我们已经记录了 13 辆汽车通过收费站时的车龄和速度。让我们看看我们收集的数据是否可以用于线性回归:

实例

首先绘制散点图:

import matplotlib.pyplot as plt

x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]

plt.scatter(x, y)
plt.show()

结果:

运行实例 »

实例

导入 scipy 并绘制线性回归线:

import matplotlib.pyplot as plt
from scipy import stats

x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]

slope, intercept, r, p, std_err = stats.linregress(x, y)

def myfunc(x):
  return slope * x + intercept

mymodel = list(map(myfunc, x))

plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()

结果:

运行实例 »

实例解析

导入所需模块:

您可以在我们的 SciPy 教程中了解 SciPy 模块。

import matplotlib.pyplot as plt
from scipy import stats

创建表示 x 和 y 轴值的数组:

x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]

执行一个方法,该方法返回线性回归的一些重要键值:

slope, intercept, r, p, std_err = stats.linregress(x, y)

创建一个使用 slopeintercept 值的函数返回新值。这个新值表示相应的 x 值将在 y 轴上放置的位置:

def myfunc(x):
  return slope * x + intercept

通过函数运行 x 数组的每个值。这将产生一个新的数组,其中的 y 轴具有新值:

mymodel = list(map(myfunc, x))

绘制原始散点图:

plt.scatter(x, y)

绘制线性回归线:

plt.plot(x, mymodel)

显示图:

plt.show()


R-Squared

重要的是要知道 x 轴的值和 y 轴的值之间的关系有多好,如果没有关系,则线性回归不能用于预测任何东西。

该关系用一个称为 r 平方(r-squared)的值来度量。

r 平方值的范围是 0 到 1,其中 0 表示不相关,而 1 表示 100% 相关。

Python 和 Scipy 模块将为您计算该值,您所要做的就是将 x 和 y 值提供给它:

实例

我的数据在线性回归中的拟合度如何?

from scipy import stats

x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]

slope, intercept, r, p, std_err = stats.linregress(x, y)

print(r)
亲自试一试 »

注释: 结果 -0.76 表明存在某种关系,但不是完美的关系,但它表明我们可以在将来的预测中使用线性回归。


预测未来价值

现在,我们可以使用收集到的信息来预测未来的值。

例如:让我们尝试预测一辆拥有 10 年历史的汽车的速度。

为此,我们需要与上例中相同的 myfunc() 函数:

def myfunc(x):
  return slope * x + intercept

实例

预测一辆有 10年车龄的汽车的速度:

from scipy import stats

x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]

slope, intercept, r, p, std_err = stats.linregress(x, y)

def myfunc(x):
  return slope * x + intercept

speed = myfunc(10)

print(speed)
运行实例 »

该例预测速度为 85.6,我们也可以从图中读取:


糟糕的拟合度?

让我们创建一个实例,其中的线性回归并不是预测未来值的最佳方法。

实例

x 和 y 轴的这些值将导致线性回归的拟合度非常差:

import matplotlib.pyplot as plt
from scipy import stats

x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]

slope, intercept, r, p, std_err = stats.linregress(x, y)

def myfunc(x):
  return slope * x + intercept

mymodel = list(map(myfunc, x))

plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()

结果:

运行实例 »

以及 r-squared 值?

实例

您应该得到了一个非常低的 r-squared 值。

import numpy
from scipy import stats

x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y = [21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]

slope, intercept, r, p, std_err = stats.linregress(x, y)

print(r)
亲自试一试 »

结果:0.013 表示关系很差,并告诉我们该数据集不适合线性回归。