一直想手写一遍 Promise，看完国内外的各种教程后，决定用 TS 实现一遍。写完之后，觉得没有必要再写一篇繁琐的文章来复现完整的步骤，因此这篇文章仅作为记录。
下面的实现代码已通过 Promises/A+ 规范的 872 个测试用例。
// Promise states. A promise starts out PENDING and, once settled, moves to
// exactly one of FULFILLED or REJECTED.
const PENDING = "pending";
const FULFILLED = "fulfilled";
const REJECTED = "rejected";

/** Settle function handed to an executor: resolves the promise with `value`. */
type PromiseResolveFn = (value: unknown) => void;
/** Settle function handed to an executor: rejects the promise with `reason`. */
type PromiseRejectFn = (reason: unknown) => void;
/** Function passed to the Promise constructor; receives both settle functions. */
type PromiseExecutor = (
  resolve: PromiseResolveFn,
  reject: PromiseRejectFn
) => void;

/** Success handler for `then`; may return a plain value or a promise. */
type ThenOnFulfilled = (value: unknown) => Promise<unknown> | unknown;
/** Failure handler for `then`; may return a plain value or a promise. */
type ThenOnRejected = (reason: unknown) => Promise<unknown> | unknown;

/** Anything exposing a `then` member (a "thenable"). */
interface Thenable {
  then: Function;
}
/**
 * A Promises/A+ compliant promise, plus the usual ES static helpers
 * (`resolve`, `reject`, `all`, `allSettled`, `race`) and `catch` / `finally`.
 *
 * Relies on the module-level helpers `promiseResolver` (the A+ Promise
 * Resolution Procedure) and `runMircoTask` (asynchronous scheduling).
 */
class Promise<T> {
  /** PENDING until settled, then FULFILLED or REJECTED; never changes afterwards. */
  private status = PENDING;
  /** Fulfillment value or rejection reason, depending on `status`. */
  private value: unknown = undefined;
  /** Reactions registered by `then` while pending; flushed once on fulfillment. */
  private onResolvedCallbacks: Array<() => void> = [];
  /** Reactions registered by `then` while pending; flushed once on rejection. */
  private onRejectedCallbacks: Array<() => void> = [];

  /**
   * @param executor Invoked synchronously with `resolve` and `reject`. If it
   *   throws before either was called, the promise rejects with the thrown value.
   */
  constructor(executor: PromiseExecutor) {
    // After the first resolve/reject call, later calls are ignored, even while
    // the promise is still pending because it is following a thenable.
    let alreadyResolved = false;

    // Final state transition: record the outcome and run the matching reactions.
    const settle = (state: string, result: unknown): void => {
      if (this.status !== PENDING) return;
      this.status = state;
      this.value = result;
      const reactions =
        state === FULFILLED ? this.onResolvedCallbacks : this.onRejectedCallbacks;
      // Drop both queues: each reaction runs at most once, and handlers of the
      // other outcome must never run.
      this.onResolvedCallbacks = [];
      this.onRejectedCallbacks = [];
      reactions.forEach((reaction) => reaction());
    };
    const fulfill = (value: unknown): void => settle(FULFILLED, value);
    const rejectNow = (reason: unknown): void => settle(REJECTED, reason);

    const resolve: PromiseResolveFn = (value) => {
      if (alreadyResolved) return;
      alreadyResolved = true;
      // Adopt thenables (including other promises) instead of fulfilling with
      // the thenable object itself; also rejects self-resolution with a TypeError.
      promiseResolver(this, value, fulfill, rejectNow);
    };
    const reject: PromiseRejectFn = (reason) => {
      if (alreadyResolved) return;
      alreadyResolved = true;
      rejectNow(reason);
    };

    try {
      executor(resolve, reject);
    } catch (error) {
      reject(error);
    }
  }

  /**
   * Registers fulfillment/rejection handlers (Promises/A+ 2.2).
   *
   * Non-function handlers are ignored and the outcome passes through unchanged.
   * Handlers always run asynchronously, and their return value (plain or
   * thenable) resolves the returned promise; a thrown error rejects it.
   */
  public then(
    onFulfilled?: ThenOnFulfilled | null,
    onRejected?: ThenOnRejected | null
  ): Promise<unknown> {
    const handleFulfilled: ThenOnFulfilled =
      typeof onFulfilled === "function" ? onFulfilled : (value) => value;
    const handleRejected: ThenOnRejected =
      typeof onRejected === "function"
        ? onRejected
        : (reason) => {
            throw reason;
          };

    const currentPromise: Promise<unknown> = new Promise((resolve, reject) => {
      // Builds a reaction that runs `handler` on the settled value in a later
      // task and feeds the result to the derived promise's `resolve`, which runs
      // the resolution procedure (self-check, thenable adoption).
      const reactWith =
        (handler: ThenOnFulfilled | ThenOnRejected) => (): void =>
          runMircoTask(() => {
            try {
              resolve(handler(this.value));
            } catch (err) {
              reject(err);
            }
          });
      const onFulfilledFn = reactWith(handleFulfilled);
      const onRejectedFn = reactWith(handleRejected);

      if (this.status === FULFILLED) {
        onFulfilledFn();
      } else if (this.status === REJECTED) {
        onRejectedFn();
      } else {
        this.onResolvedCallbacks.push(onFulfilledFn);
        this.onRejectedCallbacks.push(onRejectedFn);
      }
    });
    return currentPromise;
  }

  /** Shorthand for `then(null, onRejected)`. */
  public catch(onRejected?: ThenOnRejected | null): Promise<unknown> {
    return this.then(null, onRejected);
  }

  /**
   * Runs `callback` (with no arguments) once this promise settles, then passes
   * the original value/reason through. If `callback` throws or returns a
   * rejecting promise, that rejection wins instead.
   */
  public finally(callback?: (() => unknown) | null): Promise<unknown> {
    if (typeof callback !== "function") {
      // Non-callable: behaves like then(callback, callback), i.e. pass-through.
      return this.then(callback, callback);
    }
    const onFinally: () => unknown = callback;
    return this.then(
      (value) => Promise.resolve(onFinally()).then(() => value),
      (reason) =>
        Promise.resolve(onFinally()).then(() => {
          throw reason;
        })
    );
  }

  /** Returns `value` unchanged if it is already one of these promises; otherwise wraps (and adopts) it. */
  public static resolve(value: unknown): Promise<any> {
    if (value instanceof Promise) return value;
    return new Promise((resolve) => resolve(value));
  }

  /** Returns a promise rejected with `reason`. */
  public static reject(reason: unknown): Promise<unknown> {
    return new Promise((_, reject) => reject(reason));
  }

  /** Copies an iterable into an array; throws a TypeError for non-iterables. */
  private static toArray(promises: unknown): unknown[] {
    if (
      promises == null ||
      typeof (promises as Partial<Iterable<unknown>>)[Symbol.iterator] !==
        "function"
    ) {
      throw TypeError(`${String(promises)} is not iterable`);
    }
    return Array.from(promises as Iterable<unknown>);
  }

  /**
   * Fulfills with all values, in input order, once every item fulfills; rejects
   * with the first rejection. Plain (non-thenable) items count as fulfilled.
   */
  public static all(promises: Iterable<Thenable | unknown>): Promise<unknown> {
    return new Promise((resolve, reject) => {
      // Inside the executor so a non-iterable rejects instead of throwing.
      const items = Promise.toArray(promises);
      const results: unknown[] = [];
      let remaining = items.length;
      if (remaining === 0) {
        resolve(results);
        return;
      }
      items.forEach((item, index) => {
        Promise.resolve(item).then((value) => {
          results[index] = value;
          remaining -= 1;
          if (remaining === 0) resolve(results);
        }, reject);
      });
    });
  }

  /**
   * Fulfills, once every item settles, with `{ status, value }` /
   * `{ status, reason }` records in input order. Never rejects for item outcomes.
   */
  public static allSettled(
    promises: Iterable<Thenable | unknown>
  ): Promise<unknown> {
    return new Promise((resolve) => {
      const items = Promise.toArray(promises);
      const results: unknown[] = [];
      let remaining = items.length;
      if (remaining === 0) {
        resolve(results);
        return;
      }
      const record = (index: number, outcome: unknown): void => {
        results[index] = outcome;
        remaining -= 1;
        if (remaining === 0) resolve(results);
      };
      items.forEach((item, index) => {
        Promise.resolve(item).then(
          (value) => record(index, { status: FULFILLED, value }),
          (reason) => record(index, { status: REJECTED, reason })
        );
      });
    });
  }

  /** Settles the same way as the first item to settle; stays pending for an empty iterable. */
  public static race(promises: Iterable<Thenable | unknown>): Promise<unknown> {
    return new Promise((resolve, reject) => {
      for (const item of Promise.toArray(promises)) {
        Promise.resolve(item).then(resolve, reject);
      }
    });
  }
}
/**
 * Promises/A+ "Promise Resolution Procedure", [[Resolve]](promise, x) (spec 2.3).
 *
 * Settles the promise identified by `currentPromise` through the supplied
 * `resolve` / `reject`, based on x = `onFulfilledReturnValue`:
 * - x is the promise itself: reject with a TypeError (2.3.1);
 * - x is an object or function whose `then` is callable: call `then` with x as
 *   `this` and adopt whatever it reports first (2.3.3.3);
 * - anything else, including objects whose `then` is not callable: fulfill with x.
 *
 * @param currentPromise The promise being resolved; used only for identity checks.
 * @param onFulfilledReturnValue The value x to resolve the promise with.
 * @param resolve Fulfills the promise with a final value.
 * @param reject Rejects the promise with a reason.
 */
function promiseResolver(
  currentPromise: unknown,
  onFulfilledReturnValue: unknown,
  resolve: PromiseResolveFn,
  reject: PromiseRejectFn
): void {
  const x = onFulfilledReturnValue;
  if (currentPromise === x) {
    reject(new TypeError("can not resolve promise self"));
    return;
  }
  if (!isObjectOrFunction(x)) {
    resolve(x);
    return;
  }
  // Only the first of: resolvePromise, rejectPromise, or an exception thrown by
  // `then` takes effect; anything after that is ignored (2.3.3.3.3-4).
  let resolveOrRejectHasCalled = false;
  try {
    // Read `then` exactly once: the getter may have side effects or throw (2.3.3.1-2).
    const then = (x as { then?: unknown }).then;
    if (typeof then !== "function") {
      resolve(x);
      return;
    }
    then.call(
      x,
      (y: unknown) => {
        if (resolveOrRejectHasCalled) return;
        resolveOrRejectHasCalled = true;
        if (y === x) {
          // A thenable resolving with itself would recurse forever; spec note
          // 3.6 allows rejecting such cycles with a TypeError.
          reject(new TypeError("can not resolve promise self"));
          return;
        }
        // 2.3.3.3.1: run [[Resolve]](promise, y) against the ORIGINAL promise,
        // so that y === promise is still detected as self-resolution.
        promiseResolver(currentPromise, y, resolve, reject);
      },
      (r: unknown) => {
        if (resolveOrRejectHasCalled) return;
        resolveOrRejectHasCalled = true;
        reject(r);
      }
    );
  } catch (err) {
    if (resolveOrRejectHasCalled) return;
    resolveOrRejectHasCalled = true;
    reject(err);
  }
}
/**
 * Schedules `fn` to run asynchronously, after the current call stack has
 * unwound (Promises/A+ 2.2.4 requires `then` handlers to run this way).
 *
 * Per the Promises/A+ spec this is platform-dependent: in browsers,
 * queueMicrotask or MutationObserver can provide microtasks, and when no
 * microtask API exists, a macrotask such as setTimeout is also spec-compliant.
 *
 * The "Mirco" spelling is kept because callers use this exact name.
 *
 * @param fn Callback to run later; called with no arguments.
 */
function runMircoTask(fn: () => void): void {
  if (typeof queueMicrotask === "function") {
    queueMicrotask(fn);
  } else {
    // Fallback for runtimes without queueMicrotask.
    setTimeout(fn, 0);
  }
}
/**
 * True for non-null objects (including arrays) and functions: the values that
 * may carry a `then` method and must be probed as potential thenables.
 *
 * Typed with the lowercase `object` type, which already includes functions,
 * instead of the boxed `Object` / `Function` wrapper types.
 */
function isObjectOrFunction(x: unknown): x is object {
  return x !== null && (typeof x === "function" || typeof x === "object");
}
// Public API of this module: the hand-written Promise class defined above.
export { Promise };
import { describe, test, expect, vi } from "vitest";
import { Promise } from "./Promise";
/**
 * Vitest specs for the hand-written Promise imported from "./Promise".
 *
 * Tests that assert inside a `catch` block call `expect.assertions(n)` first,
 * so they fail (instead of passing vacuously) if the promise unexpectedly fulfills.
 */
describe("Promise", () => {
  describe(".resolve()", () => {
    test("should resolve a value", async () => {
      const result = await Promise.resolve("success");
      expect(result).toBe("success");
    });
  });

  describe(".reject()", () => {
    test("should reject a value", async () => {
      expect.assertions(1);
      try {
        await Promise.reject("failure");
      } catch (error) {
        expect(error).toBe("failure");
      }
    });
  });

  describe(".race()", () => {
    test("should resolve with the first resolved promise", async () => {
      const promise1 = new Promise((resolve) =>
        setTimeout(resolve, 500, "one")
      );
      const promise2 = new Promise((resolve) =>
        setTimeout(resolve, 100, "two")
      );
      const result = await Promise.race([promise1, promise2]);
      expect(result).toBe("two");
    });

    test("should reject with the first rejected promise", async () => {
      expect.assertions(1);
      const promise1 = new Promise((_, reject) =>
        setTimeout(reject, 500, "one")
      );
      const promise2 = new Promise((_, reject) =>
        setTimeout(reject, 100, "two")
      );
      try {
        await Promise.race([promise1, promise2]);
      } catch (error) {
        expect(error).toBe("two");
      }
    });
  });

  describe(".all()", () => {
    test("should resolve when all promises have resolved", async () => {
      const promise1 = Promise.resolve("one");
      const promise2 = Promise.resolve("two");
      const result = await Promise.all([promise1, promise2]);
      expect(result).toEqual(["one", "two"]);
    });

    test("should reject when any promise rejects", async () => {
      expect.assertions(1);
      const promise1 = Promise.resolve("one");
      const promise2 = Promise.reject("two");
      try {
        await Promise.all([promise1, promise2]);
      } catch (error) {
        expect(error).toBe("two");
      }
    });
  });

  describe(".catch()", () => {
    test("should handle rejections", async () => {
      const promise = Promise.reject("error");
      const result = await promise.catch((error) => error);
      expect(result).toBe("error");
    });
  });

  describe(".finally()", () => {
    test("should be called regardless of success or failure", async () => {
      const callback = vi.fn();
      await Promise.resolve("success").finally(callback);
      expect(callback).toHaveBeenCalledTimes(1);

      callback.mockClear();
      await Promise.reject("failure")
        .catch(() => {})
        .finally(callback);
      expect(callback).toHaveBeenCalledTimes(1);
    });
  });

  describe(".allSettled()", () => {
    test("should resolve when all promises have settled", async () => {
      const promise1 = Promise.resolve("one");
      const promise2 = Promise.reject("two");
      const results = await Promise.allSettled([promise1, promise2]);
      expect(results).toEqual([
        { status: "fulfilled", value: "one" },
        { status: "rejected", reason: "two" },
      ]);
    });
  });
});