Micrograd 的内容其实和我学矩阵微分的哪些部分有些重合,当然 mit 的偏数学计算一些,用的也是很 fancy 的 Julia 语言,Andrej Karpathypython 实现。 自动微分(Automatic Differentiation):实现篇 三点水 还可以看一下拓扑排序的算法介绍。

其实就是符号推导和数值计算,介绍一下框架的底层大概是如何实现的。

看懂很简单,但是具体上手写,你别说要点还很多。

首先是 python 的类有很多魔法方法可以用到,都是我之前不知道的。(还有很多有趣的机制,比如列表推导式,)

只需要 100 行代码就可以构建一个简单的计算框架,本仓库我将仿照 karpathy 来实现一个 micrograd 的自动微分框架并且写一个 demo 来使用我实现的自动微分框架(先按照 karpathy 的 demo 来一个,然后换个数据集写一个 demo)

一些 python 的技巧

Python 的 General-Purpose Special Methods(也叫魔法方法、双下杠方法、dunder methods)是让你自定义类行为的核心机制。通过实现这些特殊方法,你的类可以像内置类型(int、str、list 等)一样自然地支持运算符、打印、长度、迭代、上下文管理器等操作。

下面系统地介绍最常用、最重要的通用特殊方法(按官方文档的分类和常见使用场景整理),并配上真实可运行的例子。

1. 基本定制(对象字符串表示)

方法作用常用场景
__repr__(self)官方字符串表示,repr(obj)调试时最好能唯一标识对象
__str__(self)人类可读的字符串,str(obj)print给最终用户看
__format__(self, spec)自定义 format() 和 f-string 行为{value:.2f} 这种格式化
__bytes__(self)bytes(obj)二进制协议
class Value:
    def __init__(self, data):
        self.data = data
    
    def __repr__(self):
        return f"Value(data={self.data})"
    
    def __str__(self):
        return f"[Value {self.data}]"
    
    def __format__(self, spec):
        if spec.endswith('h'):
            return f"Humorous {self.data}!!"
        return format(self.data, spec)
 
v = Value(3.14159)
print(repr(v))        # Value(data=3.14159)
print(v)              # [Value 3.14159]
print(f"{v:.2f}")     # 3.14
print(f"{v:h}")       # Humorous 3.14159!!

2. 运算符重载(让对象支持 + - * / 等)

2.1 二元运算符

方法对应操作符反向/增强赋值版本
__add__(self, other)self + other__iadd__
__sub__(self, other)self - other__isub__
__mul__(self, other)self * other__imul__
__truediv__(self, other)self / other__itruediv__
__floordiv__(self, other)self // other__ifloordiv__
__mod__(self, other)self % other__imod__
__pow__(self, other)self ** other__ipow__
__matmul__(self, other)self @ other__imatmul__(矩阵乘法)

2.2 反向运算符(当左操作数不支持时调用右操作数的)

方法对应右操作数的情形
__radd__(self, other)other + self
__rsub__(self, other)other - self
…(所有二元运算符都有 r 版本)

2.3 增强赋值同上 __iadd__

class Value:
    def __init__(self, data):
        self.data = data = data
    
    def __add__(self, other):
        # 支持 Value + Value 和 Value + float/int
        if isinstance(other, Value):
            return Value(self.data + other.data)
        else:
            return Value(self.data + other)
    
    def __radd__(self, other):
        # 让 5 + Value(3) 也能工作
        return Value(other + self.data)
    
    def __repr__(self):
        return f"Value({self.data})"
 
a = Value(2.0)
b = Value(-3.0)
print(a + b)        # Value(-1.0)
print(a + 10)       # Value(12.0)
print(10 + a)       # Value(12.0)  ← 触发 __radd__

3. 比较运算符

方法操作符
__lt__(self, other)<
__le__(self, other)<=
__eq__(self, other)==
__ne__(self, other)!=
__gt__(self, other)>
__ge__(self, other)>=

Python 3.10+ 推荐使用 @total_ordering 装饰器,只实现 __eq__ 和一个比较方法即可自动补全其余。

4. 容器相关特殊方法

方法作用
__hash__(self)让对象可哈希,能放进 set/dict 做 key(必须同时实现 __eq__
__bool__(self)bool(obj),默认如果没实现则看 __len__ 是否为 0
__len__(self)len(obj)
__getitem__(self, key)obj[key],支持切片
__setitem__(self, key, val)obj[key] = val
__delitem__(self, key)del obj[key]
__iter__(self)让对象可迭代,返回迭代器
__next__(self)迭代器协议(通常配合 __iter__
__contains__(self, item)item in obj

5. 上下文管理器(with 语句)

class ManagedFile:
    def __init__(self, name):
        self.name = name
    def __enter__(self):
        self.file = open(self.name, 'w')
        return self.file
    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.file:
            self.file.close()
 
with ManagedFile('hello.txt') as f:
    f.write('hello')

6. 可调用对象

class Adder:
    def __init__(self, n):
        self.n = n
    def __call__(self, x):
        return self.n + x
 
add5 = Adder(5)
print(add5(10))   # 15

7. 属性访问控制

方法作用
__getattr__(self, name)属性不存在时调用
__getattribute__(self, name)所有属性访问都先走这里(小心递归)
__setattr__(self, name, val)设置属性时调用
__delattr__(self, name)删除属性时调用
__dir__(self)dir(obj) 返回的列表

8. 描述符(descriptor)协议(高级)

__get__, __set__, __delete__, __set_name__ —— 用于实现 property、classmethod、staticmethod 等底层机制。

9. 协程相关(Python 3.5+)

async def coro():
    return 42
 
class AwaitableValue:
    def __await__(self):
        return coro().__await__()

小结:最常用的 15 个魔法方法(背下来就够日常 90% 场景)

__init__      # 构造
__repr__      # 调试友好
__str__       # 打印友好
__add__ / __mul__ ...   # 运算符
__len__       # len()
__getitem__   # obj[key]、for 循环
__setitem__
__iter__ / __next__   # 迭代器
__eq__ / __lt__ ...   # 比较、排序
__hash__      # 可哈希
__call__      # 可调用
__enter__ / __exit__  # with
__bool__      # bool()
__contains__ # in 操作符
__del__       # 对象销毁时(慎用)

掌握这些特殊方法,你的类就能像 Python 内置类型一样丝滑地融入整个语言生态。这也是构建像 NumPy、Pandas、PyTorch 这些库中“感觉像原生类型”的关键技术。

__repr__ 是 Python 中最重要、最常被问到的魔法方法之一,它的核心使命是:给开发者(尤其是调试时)一个“一眼就能看懂这个对象是什么”的字符串

官方说法是:

__repr__ 应该返回一个可以尽可能让 eval(repr(obj)) == obj 成立的字符串表示(虽然不是强制,但这是最佳实践)。

通俗一点说:__repr__ 就是对象的“身份证”

调用方式场景看到的结果来自哪里?
直接在交互式 shell 输入变量名你正在调试代码__repr__
print(repr(obj))你想看官方表示__repr__
容器里放了你的对象(如 list、dict、set)打印整个容器时自动调用每个元素的 __repr__
IDE、Jupyter、调试器显示变量几乎所有开发工具都依赖 __repr__

__repr__ vs 烂 __repr__ 对比

# 坏例子(默认行为)
class Value:
    def __init__(self, data):
        self.data = data
 
v = Value(3.14)
print(v)                    # 输出: <__main__.Value object at 0x7f8b2c0a3d90>
# 完全看不出 data 是多少!调试噩梦!
# 好例子(你刚才看到的)
class Value:
    def __init__(self, data):
        self.data = data
    
    def __repr__(self):
        return f"Value(data={self.data}) "
 
v = Value(3.14)
print(v)                    # 输出: Value(data=3.14)
print([v, v])               # 输出: [Value(data=3.14), Value(data=3.14)]

在 REPL 或 Jupyter 里直接敲变量名:

>>> a = Value(-2.0)
>>> b = Value(3.0)
>>> a + b
Value(data=1.0)        # 立刻知道结果是对的!

黄金法则(几乎所有大厂、知名库都遵守)

  1. __repr__ 必须返回字符串(str)
  2. 内容要精确、无歧义,最好能唯一标识这个对象
  3. 尽量让 eval(repr(obj)) == obj 成立(特别是自定义容器、数值类)
  4. 不要怕写长,但要信息完整(比如包含 id、形状、关键字段)

几个著名库的 __repr__ 例子

import numpy as np
import pandas as pd
import torch
 
np.array([1,2,3])
# array([1, 2, 3])
 
pd.DataFrame({'A':[1,2,3]})
#    A
# 0  1
# 1  2
# 2  3
 
torch.tensor([1.0, 2.0])
# tensor([1., 2.])
 
# 它们全都是靠精心编写的 __repr__ 才这么好看!

小结:一句话记住

__repr__ 是写给开发者(尤其是正在 debug 的自己)看的“对象身份证”,写得好能让你调试效率翻 10 倍!

所以以后每次写 class,都养成习惯第一时间加上一个清晰的 __repr__,未来的你会在半夜 debug 的时候疯狂感谢现在的你。

__init__.py

__init__.py 这个文件的作用可以一句话概括:

它把一个目录变成“包(package)”,让你可以 import 它,并在里面做一些包级别的初始化/导出。

下面分点说清楚 👇


1. 最核心的作用:把目录变成包

有这么个目录结构:

project/
    mypkg/
        __init__.py
        core.py

有了 __init__.py 之后,你就可以:

import mypkg.core
from mypkg import core

Python 解释器在执行 import mypkg.core 时,大概会这样找:

  1. sys.path 里找到名为 mypkg 的目录;

  2. 看这个目录里有没有 __init__.py

    • 有 → 这是一个 ,先执行 mypkg/__init__.py,然后再去找 core.py
    • 没有 → 在老版本 Python 里就不是包(会导入失败),在新版本里可能被当成“命名空间包”,但用法稍有区别(后面再说)。

所以:__init__.py = 这个目录是一个 Python 包,而不是普通文件夹。


2. 它里面可以写什么?常见用途

__init__.py 不是一定要空的,它其实就是一个普通的 Python 模块文件,里面可以写任何 Python 代码。常见做法有:

2.1 包初始化代码(轻量级的)

第一次 import 包名 时会执行一遍,比如:

# mypkg/__init__.py
print("mypkg imported")  # 一般不推荐这么干,只是例子

实际工程里常见的是一些轻量的东西:

  • 定义包的版本号:

    __version__ = "1.2.3"
  • 配置日志的基本格式;

  • 做一些简单的检查(比如依赖版本提示);

一般不建议在 __init__.py 里做太重的事情(比如大量 IO、网络请求),否则一 import 整个包就变慢。


2.2 控制“对外暴露什么”:重导出和 __all__

很多包会在 __init__.py 里,把内部子模块里的类/函数“捞出来”,方便用户用:

# mypkg/__init__.py
from .core import Value, Foo
 
__all__ = ["Value", "Foo"]  # 控制 from mypkg import * 时导出的名字

这样别人就可以:

from mypkg import Value   # 而不用关心 core.py 的存在

好处是:

  • 对外暴露一个简洁的 API;
  • 你可以在后面重构内部目录结构,只要 __init__.py 里保持这些导出不变,用户代码就不用改。

3. 那为啥“很多工程目录都有这个文件”?

原因其实就是:

  1. 每一层目录都想当包

    • 比如:

      myproj/
          myproj/
              __init__.py
              models/
                  __init__.py
                  user.py
                  order.py
              utils/
                  __init__.py
                  io.py
    • 这样你就可以写:

      from myproj.models import user
      from myproj.utils.io import read_file
    • 以及在包内部用相对导入:

      # models/user.py
      from ..utils.io import read_file
  2. 兼容性和工具支持

    • 老版本 Python(3.3 之前)没有 __init__.py 就不会把目录当包;
    • 很多工具(IDE、测试框架、打包工具)也习惯以 __init__.py 为标志来判断“这里是个包”;
    • 所以很多项目干脆每个包目录都放一个(哪怕是空的)__init__.py,简单粗暴,最省心。

4. 和类里的 __init__ 有什么关系?

  • __init__.py文件名,是“包模块”的入口;
  • __init__类的方法名,是“实例初始化函数”。

它们只是名字长得像,本质上完全是两回事:

  • 一个作用于 (目录);
  • 一个作用于 对象实例

5. 现在不是可以没有 __init__.py 吗?

是的,从 Python 3.3 开始有了 命名空间包(namespace package) 的概念:

  • 某些情况下,一个包可以没有 __init__.py,Python 也能把它当成包;
  • 但这一般用于比较复杂的分发场景(比如同一个包名被拆成多个独立安装的子项目)。

在日常项目里,大部分人还是:

  • 习惯性地在需要当包的目录里放一个(即便是空的)__init__.py
  • 这样最直观,也兼容所有 Python 版本和各种工具。

总结一句

  • __init__.py 用来:把目录变成包 + 做包级别的初始化/导出

  • 很多工程目录都有,是因为:

    • 想让这些目录可 import
    • 方便写相对导入;
    • 兼容性好、工具支持好。

如果你愿意,我可以根据你现在的项目目录,帮你一起设计哪些地方应该有 __init__.py,顺便整理一下导入结构,让代码更清晰。

拓扑排序

其实就是需要不违反依赖顺序的调用一遍有向无环图(DAG)(其实就是后序遍历)

def backward(self):
	# topological order all of the children in the graph
	topo = []
	visited = set()
	def build_topo(v: Value):
		if v not in visited:
			vistied.add(v)
			for child in v._prev:
				build_topo(child)
			topo.append(v)
	build_topo(self)
	# 获得正向的topo排序
	self.grad = 1.0
	# 开始反向传播(所以还是需要reversed一下正向的topo排序)
	for node in reversed(topo):
		node._backward()
 
 

多次调用的 bug

Karpathy 举了一个简单的例子来说明

a = Value(3.0, label='a')
b = a + a   ; b.label = 'b'
b.backward()
draw_dot(b)
 

如果是如下就会出 bug,因为 a 相当于被调用两次(这里在一个式子中被调用两次相当于 self 和 other 是一个对象,也就是被 override 了)(还有的时候是在多个层级中被调用多次,这时候如果用 = 号依旧是被 override,所以注意要累加每次被调用时传递过来的梯度)

def __add__(self, other):
  out = Value(self.data + other.data, (self, other), '+')
  
  def _backward():
    self.grad  = 1.0 * out.grad
    other.grad = 1.0 * out.grad
  out._backward = _backward
  
  return out
 

修改成这样

 
def __add__(self, other):
  out = Value(self.data + other.data, (self, other), '+')
  
  def _backward():
    self.grad += 1.0 * out.grad
    other.grad += 1.0 * out.grad
  out._backward = _backward
  
  return out
 

要支持复杂的输入

Karpathy 的教程非常细致入微啊,确实是得看视频,大模型直接讲他的成品代码,完全没太讲清楚其中的细枝末节(当然也是因为我对 python 的了解不够)

下面我系统、清晰地讲解 Python 里的 “右侧运算符特殊方法(reverse operator methods)”,例如 __radd____rmul____rsub__ 等它们是什么、为什么存在、什么时候被调用。

这类方法是 Python 的 双目运算符重载机制的重要组成,在你写微分引擎(比如 micrograd)、矩阵库、Tensor 类时非常关键。


✨ 一、Python 的运算符重载是“双向”的

当你写:

a + b

Python 内部其实执行这两步:

  1. 调用左边操作数 a.__add__(b)
  2. 如果第一步返回 NotImplemented,或者类型不兼容,则调用右边的 b.__radd__(a)

一个图就能理解:

a + b
│
├── 尝试 a.__add__(b)
│          │
│          ├── 如果成功则用它
│          └── 如果失败
│
└── 尝试 b.__radd__(a)
           │
           └── 若成功则返回结果

✨ 二、右侧运算符方法(rxxx)为什么存在?

原因:左右两边可能不是同一类型

例如:

3 + myValue

计算顺序:

3.__add__(myValue)  → int.__add__ 不认识 Value → 返回 NotImplemented
myValue.__radd__(3) → 被调用

所以:

如果你希望 3 + your_object 也能工作, 就必须实现 __radd__

类似地:

  • __rmul__ → 应对 3 * your_obj
  • __rsub__ → 应对 3 - your_obj
  • __rtruediv__ → 应对 3 / your_obj

✨ 三、常见的 forward / reverse 对

运算符左侧方法右侧方法(fallback)
a + b__add____radd__
a - b__sub____rsub__
a * b__mul____rmul__
a / b__truediv____rtruediv__
a // b__floordiv____rfloordiv__
a ** b__pow____rpow__

大家都有右侧版本!


✨ 四、一个实际例子:为什么 micrograd 需要 radd

在 micrograd 里你可能写:

sum(values)

Python 计算 sum 的方法是:

0 + values[0] + values[1] + ...

因为最开始是:

0 + Value(...)

0int,所以走:

0.__add__(Value) → 不认识 → NotImplemented
Value.__radd__(0) → 被调用

因此必须在 Value 里实现:

def __radd__(self, other):
    return self + other

否则 sum() 会报错。


✨ 五、一个更完整的示例(Value 类)

假设你有一个自定义类:

class Value:
    def __init__(self, data):
        self.data = data
 
    def __add__(self, other):
        print("左加被调用")
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data)
 
    def __radd__(self, other):
    	## reverse add本质上就是为了处理 1 + a -> a.__radd__(other)
        return self + other

为什么 python 不自动帮我们实现 _radd?(python 的有个哲学是,尽量避免像 C++那样有着太多的潜藏话语,保持简单和明白)(不是所有的运算符都适合自动实现反向,比如减法是不可以直接交换位置)

为什么 python 在遇到 1+a 的时候会自动尝试 __radd__ 方法,这是因为 python 的双向运算符查找机制(Fallback Mechanism)

Python 的运算符查找顺序 当执行 1 + a 时,Python 按以下顺序查找:

# 步骤1:尝试左操作数的 __add__
result = (1).__add__(a)  # int.__add__(a)
# 如果 int 不知道如何处理 Value 对象,返回 NotImplemented
 
# 步骤2:如果返回 NotImplemented,尝试右操作数的 __radd__
if result is NotImplemented:
    result = a.__radd__(1)  # Value.__radd__(1)
    # 这就是为什么 Python 会自动调用 __radd__

为什么需要这个机制?

Python

other = other if isinstance(other, Value) else Value(other)
# isinstance是为other检查是否为Value实例

测试:

x = Value(10)
y = Value(5)
 
z1 = x + y        # 调用 __add__
z2 = 3 + x        # 调用 __radd__

输出:

左加被调用
右加被调用

✨ 六、什么时候一定要实现右侧方法?

以下情况几乎必须实现:

1. 支持常数和对象混合运算

例如:

2 * my_tensor

就需要:

__rmul__

2. 支持 sum ()

因为 sum 从 0 开始累积。

3. 左方法不对称时

比如你希望:

x - y

3 - x

都能正常工作,则必须实现:

__rsub__

✨ 七、右侧方法完全等价于交换顺序吗?

不完全!

你可以在右侧方法里定义不同的行为

例如 Python 中:

int.__add__(numpy_array) → NotImplemented
numpy_array.__radd__(int) → 做数组广播加法

这是合理且常见的。


✨ 八、一个容易忽略的细节

如果 add 返回的是 NotImplemented(不是错误),Python 才会继续尝试 __radd__

例如:

def __add__(self, other):
    return NotImplemented

表示:“我不处理这种加法,让右侧试试”。


📌 总结(一句话版)

右侧运算符方法(radd / rmul 等)是当左侧运算符不支持时的 fallback,它让表达式中的右侧对象也能决定运算逻辑。

你自定义的类如果希望支持:

3 + x
2 * x
10 / x
sum(xs)

就必须实现相应的 __rxxx__


Build out a neural net library in micrograd

过去对 MLP 算法的认识太浅了(先让 Gemini 给我可视化解释一下)

class Neuron:
    def __init__(self, nin):
        self.w = [Value(random.uniform(-1, 1)) for _ in range(nin)]
        self.b = Value(random.uniform(-1, 1))
        
    def __call__(self, x):
        # x is a vector
        act = sum((wi*xi for wi, xi in zip(self.w, x)), self.b)
        

Python 的 list comprehension 列表推导式

这句话运用了 Python 中非常常用且强大的语法特性:列表推导式 (List Comprehension)

这行代码的具体含义是:创建一个包含 nin 个元素的列表,其中每个元素都是一个初始化为随机数的 Value 对象,并将这个列表赋值给实例变量 self.w

下面我将其拆解为 4 个部分详细解释:

1. 列表推导式 [... for ... in ...]

这是 Python 创建列表的一种简洁方式。

  • 语法结构[表达式 for 变量 in 可迭代对象]
  • 对应代码[Value(...) for _ in range(nin)]
  • 作用:它相当于一个简化的 for 循环。如果用普通写法,它等同于:
    temp_list = []
    for _ in range(nin):
        temp_list.append(Value(random.uniform(-1, 1)))
    self.w = temp_list

2. range(nin)

  • 含义:生成一个从 0 到 nin - 1 的数字序列。
  • 作用:在这里,它主要用于控制循环的次数。nin 通常代表 “number of inputs”(输入特征的数量)。如果 nin 是 3,循环就会执行 3 次。

3. 下划线 _ 作为变量名

  • 含义:在 Python 中,如果我们循环时不需要使用循环变量的值(比如这里我们只是单纯想循环 nin 次,而不关心当前是第几次),通常约定俗成地使用下划线 _ 来作为变量名。
  • 作用:告诉阅读代码的人:“这是一个占位符,这个变量的值我们在后面不会用到。”

4. Value(random.uniform(-1, 1)) (循环体内的表达式)

这是列表生成的每一个元素的内容:

  • random.uniform(-1, 1): 生成一个范围在 -1.0 到 1.0 之间的随机浮点数。
  • Value(...): 将这个随机数封装进 Value 类中。在 Micrograd(你的项目背景)中,Value 通常是用来支持自动微分的包装类。

总结

这行代码的完整逻辑是:

“循环 nin 次(for _ in range(nin)),每次生成一个 -1 到 1 之间的随机数并包装成 Value 对象(Value(random.uniform(-1, 1))),最后把这 nin 个对象组成一个列表([...]),赋值给 self.w(即神经元的权重列表)。”

我对 python 语法了解的太少了

不熟悉这种句式

inputs = [list(map(Value, xrow)) for xrow in Xb]

先把结论说在前面:

  • map 是一个“批量套用函数”的工具:对序列里的每个元素都执行同一个函数。
  • 可迭代对象(iterable):能被 for 循环的东西,比如 liststrrange
  • 迭代器(iterator):能被一个一个取出元素的东西,有 __next__ 方法,比如 map 返回的就是一个迭代器。

下面一点点讲 👇


一、什么是“可迭代对象”(iterable)?

直观理解

能够“被一个一个遍历”的对象,就是“可迭代对象”。

比如:

lst = [1, 2, 3]     # 列表
s = "hello"         # 字符串
r = range(5)        # range 对象
d = {"a": 1, "b": 2}# 字典

这些都可以写:

for x in lst: ...
for ch in s: ...
for i in r: ...
for key in d: ...

只要能用 for ... in ... 遍历,就是 iterable

更底层一点(你知道就好,不用死记):

  • 一个对象如果实现了 __iter__() 方法,就被认为是“可迭代的”。
  • for 循环会自动调用:iter(obj),得到一个 迭代器

二、什么是“迭代器”(iterator)?

迭代器是一个“能记住当前位置,并且每次给你下一个元素”的对象。

特点:

  1. __iter__() 方法(返回自己)
  2. __next__() 方法(每调用一次,返回下一个元素;没元素时抛出 StopIteration

我们可以手工玩一下迭代器:

lst = [10, 20, 30]
it = iter(lst)      # 从可迭代对象 lst 得到一个迭代器
 
print(next(it))     # 10
print(next(it))     # 20
print(next(it))     # 30
# 再 next(it) 就会抛 StopIteration

for 循环本质上就在偷偷调用 next()


三、map 是什么?

基本形式

map(function, iterable)
  • function:一个函数,接收一个元素,返回一个新值
  • iterable:一个可迭代对象(列表、元组、字符串、range 等)

作用:

构造一个“迭代器”,这个迭代器会把 iterable 里的每个元素,依次丢给 function,得到一串新结果。

但注意:map 本身不会立刻执行完全部计算,它返回的是一个 迭代器

简单例子

# 把列表里的每个数都平方
def square(x):
    return x * x
 
nums = [1, 2, 3, 4]
m = map(square, nums)
print(m)   # <map object at 0x...>,这是一个迭代器

如果想看到结果,有两种常见做法:

① 用 list 把它“物化”为列表

result = list(m)
print(result)  # [1, 4, 9, 16]

你在之前代码里看到的 list(map(Value, xrow)) 就是这种用法。

② 用 for 循环消耗 map 迭代器

for item in map(square, nums):
    print(item)

四、map 和列表推导式的对比

很多时候,用 列表推导式 更直观,比如:

# 用 map
result = list(map(square, nums))
 
# 用列表推导式(等价)
result = [square(x) for x in nums]

再比如你前面那行:

inputs = [list(map(Value, xrow)) for xrow in Xb]

其实可以理解为:

inputs = [[Value(x) for x in xrow] for xrow in Xb]

两种都对,只是风格不同:

  • 列表推导式:更 Pythonic,更好读。
  • map:写起来更短,跟函数式编程的风格更像。

五、map 支持多个可迭代对象

map 还可以同时接多个 iterable:

def add(a, b):
    return a + b
 
xs = [1, 2, 3]
ys = [10, 20, 30]
 
result = list(map(add, xs, ys))
print(result)  # [11, 22, 33]

原理:

  • 第一次:add(1, 10)
  • 第二次:add(2, 20)
  • 第三次:add(3, 30)

六、从“可迭代对象”到“迭代器”的链条

把你学过的串一下:

  1. list / tuple / str / range / dict 等等是 iterable

  2. iter(iterable) ⇒ 得到一个 iterator

  3. map(function, iterable) ⇒ 直接生成一个 iterator

  4. for x in something:

    • 如果 something 是 iterable:Python 会自动调用 iter(something) 得到迭代器
    • 然后反复调用 next() 取元素,直到 StopIteration

你可以把这些看成是“迭代协议”的一套机制。


七、结合你之前的代码再看一眼

那行代码:

inputs = [list(map(Value, xrow)) for xrow in Xb]

按现在的理解:

  1. Xb 是一个可迭代对象(可以被 for xrow in Xb 遍历)
  2. 每一个 xrow 本身又是一个可迭代对象(比如一行向量)
  3. map(Value, xrow) 生成一个迭代器,它会对 xrow 里的每个元素执行 Value(元素)
  4. list( ... ) 把这个迭代器中的结果“取出来”组成一个列表
  5. 外层列表推导式 [...] for xrow in Xb 收集所有这些行

结果是:inputs 是“列表(每个元素又是 Value 对象组成的列表)”。


Python 的上下文管理器

with gzip.open(image_filename, 'rb') as f_img:
	...
  • with
    • 帮你自动做”打开” “使用” “关闭”资源(比如文件,网络连接)