手写 Promise (TS 版本)

一直想手写一遍 Promise,看完了国内外各种教程后,决定用 TS 写一遍。写完之后感觉没有必要再写一篇繁琐的文章来重现完整的步骤,因此这篇文章就仅当作记录吧。

代码实现

下面的实现代码已通过 Promises/A+ 测试套件(promises-aplus-tests)的 872 个测试用例(test cases)

/** State of a promise that has not been settled yet. */
const PENDING = "pending";
/** State of a promise that was settled with a value. */
const FULFILLED = "fulfilled";
/** State of a promise that was settled with a rejection reason. */
const REJECTED = "rejected";

/** Settles a promise as fulfilled with `value`. */
type PromiseResolveFn = (value: unknown) => void;

/** Settles a promise as rejected with `reason`. */
type PromiseRejectFn = (reason: unknown) => void;

/** Function handed to the constructor; it receives the two settle functions. */
type PromiseExecutor = (resolve: PromiseResolveFn, reject: PromiseRejectFn) => void;

/** Fulfillment handler accepted by `then`; may return a plain value or a promise. */
type ThenOnFulfilled = (value: unknown) => Promise<unknown> | unknown;

/** Rejection handler accepted by `then`; may return a plain value or a promise. */
type ThenOnRejected = (reason: unknown) => Promise<unknown> | unknown;

/** Any object exposing a `then` method. */
interface Thenable {
  then: Function;
}

/**
 * Hand-written promise modelled on the Promises/A+ specification.
 *
 * An instance starts out pending and is settled at most once, either with a
 * value (fulfilled) or with a reason (rejected). Handlers registered through
 * `then` are always invoked asynchronously via `runMircoTask`.
 */
class Promise<T> {
  /** Current state: PENDING, FULFILLED or REJECTED. */
  private status = PENDING;
  /** Fulfillment value or rejection reason, depending on `status`. */
  private value: unknown = undefined;
  /** Handlers queued by `then` while pending; run once on fulfillment. */
  private onResolvedCallbacks: Array<() => void> = [];
  /** Handlers queued by `then` while pending; run once on rejection. */
  private onRejectedCallbacks: Array<() => void> = [];

  /**
   * Runs `executor` synchronously with the two settle functions.
   *
   * Only the first call to either settle function has any effect. `resolve`
   * stores the given value as-is (it does not unwrap thenables). An exception
   * thrown by `executor` rejects the promise.
   *
   * @param executor Function that starts the work and eventually settles.
   */
  constructor(executor: PromiseExecutor) {
    const resolve = (value: unknown) => {
      if (this.status !== PENDING) return;
      this.status = FULFILLED;
      this.value = value;
      this.onResolvedCallbacks.forEach((fn) => fn());
    };

    const reject = (reason: unknown) => {
      // The guard must also cover the callbacks: otherwise a late reject()
      // would run rejection handlers on an already-settled promise.
      if (this.status !== PENDING) return;
      this.status = REJECTED;
      this.value = reason;
      this.onRejectedCallbacks.forEach((fn) => fn());
    };

    try {
      executor(resolve, reject);
    } catch (error) {
      reject(error);
    }
  }

  /**
   * Registers handlers and returns a new promise settled from their outcome.
   *
   * A non-function `onFulfilled` passes the value through; a non-function
   * `onRejected` re-throws the reason. A handler's return value goes through
   * `promiseResolver`, so returned thenables are adopted. A throwing handler
   * rejects the returned promise.
   *
   * @param onFulfilled Called with the value once this promise is fulfilled.
   * @param onRejected Called with the reason once this promise is rejected.
   * @returns The derived promise.
   */
  public then(
    onFulfilled?: ThenOnFulfilled | null,
    onRejected?: ThenOnRejected | null
  ): Promise<unknown> {
    const handleFulfilled: ThenOnFulfilled =
      typeof onFulfilled === "function" ? onFulfilled : (value) => value;
    const handleRejected: ThenOnRejected =
      typeof onRejected === "function"
        ? onRejected
        : (reason) => {
            throw reason;
          };

    const currentPromise: Promise<unknown> = new Promise((resolve, reject) => {
      // Builds the deferred job for one handler. `currentPromise` is only read
      // inside the microtask, i.e. after the constructor call has returned.
      const schedule = (handler: ThenOnFulfilled | ThenOnRejected) => () =>
        runMircoTask(() => {
          try {
            promiseResolver(
              currentPromise,
              handler(this.value),
              resolve,
              reject
            );
          } catch (err) {
            reject(err);
          }
        });

      const onFulfilledFn = schedule(handleFulfilled);
      const onRejectedFn = schedule(handleRejected);

      if (this.status === FULFILLED) {
        onFulfilledFn();
      } else if (this.status === REJECTED) {
        onRejectedFn();
      } else {
        this.onResolvedCallbacks.push(onFulfilledFn);
        this.onRejectedCallbacks.push(onRejectedFn);
      }
    });
    return currentPromise;
  }

  /**
   * Shorthand for `then(null, onRejected)`.
   *
   * @param onRejected Called with the reason once this promise is rejected.
   */
  public catch(onRejected?: ThenOnRejected | null): Promise<unknown> {
    return this.then(null, onRejected);
  }

  /**
   * Runs `callback` once this promise settles, whatever the outcome.
   *
   * `callback` receives no arguments and its return value is ignored: the
   * returned promise keeps the original value or reason. If `callback` throws
   * or returns a rejected thenable, the returned promise rejects with that
   * error instead. A non-function `callback` is ignored.
   *
   * @param callback Cleanup function to run after settlement.
   */
  public finally(callback?: (() => unknown) | null): Promise<unknown> {
    if (typeof callback !== "function") {
      return this.then();
    }
    return this.then(
      (value) => Promise.resolve(callback()).then(() => value),
      (reason) =>
        Promise.resolve(callback()).then(() => {
          throw reason;
        })
    );
  }

  /**
   * Converts `value` into a promise.
   *
   * An instance of this class is returned unchanged, a thenable is adopted
   * through `promiseResolver`, and anything else yields a promise fulfilled
   * with `value`.
   */
  public static resolve(value: any): Promise<any> {
    if (value instanceof Promise) {
      return value;
    }
    // The executor runs synchronously, so both placeholders are replaced
    // before promiseResolver is called. Capturing them lets us pass the new
    // promise itself for the self-resolution check.
    let settleResolve: PromiseResolveFn = () => {};
    let settleReject: PromiseRejectFn = () => {};
    const promise = new Promise((resolve, reject) => {
      settleResolve = resolve;
      settleReject = reject;
    });
    promiseResolver(promise, value, settleResolve, settleReject);
    return promise;
  }

  /** Returns a promise rejected with `reason`. */
  public static reject(reason?: unknown): Promise<unknown> {
    return new Promise((_, reject) => reject(reason));
  }

  /**
   * Materialises an iterable into an array.
   *
   * @throws TypeError when `iterable` does not implement `Symbol.iterator`.
   */
  private static toArray(iterable: unknown): unknown[] {
    if (
      iterable == null ||
      typeof (iterable as Partial<Iterable<unknown>>)[Symbol.iterator] !==
        "function"
    ) {
      throw TypeError(`${iterable} is not iterable`);
    }
    return Array.from(iterable as Iterable<unknown>);
  }

  /**
   * Fulfills with an array of all values, in input order, once every item has
   * fulfilled; rejects with the first rejection reason.
   *
   * Items that are not promises are wrapped with `Promise.resolve`. An empty
   * input fulfills with `[]`; a non-iterable input rejects with a TypeError.
   *
   * @param promises Iterable of promises, thenables or plain values.
   */
  public static all(promises: Iterable<unknown>): Promise<unknown> {
    return new Promise((resolve, reject) => {
      // Throwing inside the executor turns the error into a rejection.
      const items = Promise.toArray(promises);
      const results: unknown[] = [];
      if (items.length === 0) {
        resolve(results);
        return;
      }

      let remaining = items.length;
      items.forEach((item, index) => {
        Promise.resolve(item).then((value) => {
          results[index] = value;
          remaining -= 1;
          if (remaining === 0) {
            resolve(results);
          }
        }, reject);
      });
    });
  }

  /**
   * Fulfills, once every item has settled, with an array of
   * `{ status: "fulfilled", value }` / `{ status: "rejected", reason }`
   * records in input order.
   *
   * Items that are not promises are wrapped with `Promise.resolve`. An empty
   * input fulfills with `[]`; a non-iterable input rejects with a TypeError.
   *
   * @param promises Iterable of promises, thenables or plain values.
   */
  public static allSettled(promises: Iterable<unknown>): Promise<unknown> {
    return new Promise((resolve) => {
      const items = Promise.toArray(promises);
      const results: unknown[] = [];
      if (items.length === 0) {
        resolve(results);
        return;
      }

      let remaining = items.length;
      const record = (entry: unknown, index: number) => {
        results[index] = entry;
        remaining -= 1;
        if (remaining === 0) {
          resolve(results);
        }
      };

      items.forEach((item, index) => {
        Promise.resolve(item).then(
          (value) => record({ status: FULFILLED, value }, index),
          (reason) => record({ status: REJECTED, reason }, index)
        );
      });
    });
  }

  /**
   * Settles the same way as the first item that settles.
   *
   * Items that are not promises are wrapped with `Promise.resolve`. An empty
   * input leaves the returned promise pending; a non-iterable input rejects
   * with a TypeError.
   *
   * @param promises Iterable of promises, thenables or plain values.
   */
  public static race(promises: Iterable<unknown>): Promise<unknown> {
    return new Promise((resolve, reject) => {
      Promise.toArray(promises).forEach((item) =>
        Promise.resolve(item).then(resolve, reject)
      );
    });
  }
}

/**
 * Promise resolution procedure: settles a promise from the value returned by
 * one of its handlers.
 *
 * - If the value is the promise itself, rejects with a TypeError.
 * - If the value is an object or function with a callable `then`, calls that
 *   `then` and follows its outcome; only the first callback invocation counts,
 *   and an exception thrown afterwards is ignored.
 * - Otherwise fulfills with the value.
 *
 * @param currentPromise Promise being settled; used for the self-resolution check.
 * @param onFulfilledReturnValue Value to resolve the promise with.
 * @param resolve Fulfills `currentPromise`.
 * @param reject Rejects `currentPromise`.
 */
function promiseResolver(
  currentPromise: unknown,
  onFulfilledReturnValue: unknown,
  resolve: PromiseResolveFn,
  reject: PromiseRejectFn
) {
  if (currentPromise === onFulfilledReturnValue) {
    reject(new TypeError("can not resolve promise self"));
    return;
  }

  if (!isObjectOrFunction(onFulfilledReturnValue)) {
    resolve(onFulfilledReturnValue);
    return;
  }

  // Guards against thenables that invoke their callbacks more than once.
  let resolveOrRejectHasCalled = false;
  try {
    // Read `then` exactly once: it may be a getter with side effects.
    const then: unknown = (onFulfilledReturnValue as { then?: unknown }).then;
    if (typeof then !== "function") {
      resolve(onFulfilledReturnValue);
      return;
    }

    then.call(
      onFulfilledReturnValue,
      (newMaybePromise: unknown) => {
        if (resolveOrRejectHasCalled) return;
        resolveOrRejectHasCalled = true;
        // Keep resolving against the ORIGINAL promise so that a thenable
        // which eventually yields that promise is detected as a cycle.
        promiseResolver(currentPromise, newMaybePromise, resolve, reject);
      },
      (newMaybePromiseReason: unknown) => {
        if (resolveOrRejectHasCalled) return;
        resolveOrRejectHasCalled = true;
        reject(newMaybePromiseReason);
      }
    );
  } catch (err) {
    // An exception thrown after a callback already fired must be ignored.
    if (resolveOrRejectHasCalled) return;
    resolveOrRejectHasCalled = true;
    reject(err);
  }
}

/**
 * Schedules `fn` to run asynchronously, with no arguments.
 *
 * Per the Promises/A+ specification the scheduling mechanism is
 * platform-dependent. `queueMicrotask` is used when it is available (in
 * browsers a MutationObserver could be used to emulate it); when no microtask
 * API exists, falling back to `setTimeout` is still spec-compliant.
 *
 * @param fn Callback to run later.
 */
function runMircoTask(fn: Function) {
  // Wrap `fn` so the scheduler receives a plain `() => void` callback.
  const task = () => {
    fn();
  };

  if (typeof queueMicrotask === "function") {
    queueMicrotask(task);
  } else {
    setTimeout(task, 0);
  }
}

/**
 * Type guard: reports whether `x` is a non-null object or a function,
 * i.e. a value that can carry a `then` property.
 */
function isObjectOrFunction(x: unknown): x is Object | Function {
  if (x === null) {
    return false;
  }
  const kind = typeof x;
  return kind === "object" || kind === "function";
}

export { Promise };

静态方法及 catch / finally 的单元测试

import { describe, test, expect, vi } from "vitest";
import { Promise } from "./Promise";

// Unit tests for the static helpers and for catch / finally.
describe("Promise", () => {
  describe(".resolve()", () => {
    test("should resolve a value", async () => {
      const result = await Promise.resolve("success");
      expect(result).toBe("success");
    });
  });

  describe(".reject()", () => {
    test("should reject a value", async () => {
      // Fails the test if the catch branch is never reached.
      expect.assertions(1);
      try {
        await Promise.reject("failure");
      } catch (error) {
        expect(error).toBe("failure");
      }
    });
  });

  describe(".race()", () => {
    test("should resolve with the first resolved promise", async () => {
      const promise1 = new Promise((resolve) =>
        setTimeout(resolve, 500, "one")
      );
      const promise2 = new Promise((resolve) =>
        setTimeout(resolve, 100, "two")
      );
      const result = await Promise.race([promise1, promise2]);
      expect(result).toBe("two");
    });

    test("should reject with the first rejected promise", async () => {
      expect.assertions(1);
      const promise1 = new Promise((_, reject) =>
        setTimeout(reject, 500, "one")
      );
      const promise2 = new Promise((_, reject) =>
        setTimeout(reject, 100, "two")
      );
      try {
        await Promise.race([promise1, promise2]);
      } catch (error) {
        expect(error).toBe("two");
      }
    });
  });

  describe(".all()", () => {
    test("should resolve when all promises have resolved", async () => {
      const promise1 = Promise.resolve("one");
      const promise2 = Promise.resolve("two");
      const result = await Promise.all([promise1, promise2]);
      expect(result).toEqual(["one", "two"]);
    });

    test("should reject when any promise rejects", async () => {
      expect.assertions(1);
      const promise1 = Promise.resolve("one");
      const promise2 = Promise.reject("two");
      try {
        await Promise.all([promise1, promise2]);
      } catch (error) {
        expect(error).toBe("two");
      }
    });
  });

  describe(".catch()", () => {
    test("should handle rejections", async () => {
      const promise = Promise.reject("error");
      const result = await promise.catch((error) => error);
      expect(result).toBe("error");
    });
  });

  describe(".finally()", () => {
    test("should be called regardless of success or failure", async () => {
      const callback = vi.fn();
      await Promise.resolve("success").finally(callback);
      expect(callback).toHaveBeenCalledTimes(1);

      callback.mockClear();
      // Attach finally() directly to the rejected promise so the failure path
      // is the one being exercised; the trailing catch() absorbs any rejection.
      await Promise.reject("failure")
        .finally(callback)
        .catch(() => {});
      expect(callback).toHaveBeenCalledTimes(1);
    });
  });

  describe(".allSettled()", () => {
    test("should resolve when all promises have settled", async () => {
      const promise1 = Promise.resolve("one");
      const promise2 = Promise.reject("two");
      const results = await Promise.allSettled([promise1, promise2]);
      expect(results).toEqual([
        { status: "fulfilled", value: "one" },
        { status: "rejected", reason: "two" },
      ]);
    });
  });
});

CodeSandbox 在线体验