diff --git a/src/IConnectionOptions.ts b/src/IConnectionOptions.ts index b3299e3..f7dec18 100644 --- a/src/IConnectionOptions.ts +++ b/src/IConnectionOptions.ts @@ -17,6 +17,7 @@ export default interface IConnectionOptions { schemaName: string; ssl: boolean; skipTables: string[]; + onlyTables: string[]; } export function getDefaultConnectionOptions(): IConnectionOptions { @@ -30,6 +31,7 @@ export function getDefaultConnectionOptions(): IConnectionOptions { schemaName: "", ssl: false, skipTables: [], + onlyTables: [], }; return connectionOptions; } diff --git a/src/drivers/AbstractDriver.ts b/src/drivers/AbstractDriver.ts index fb9d855..7023da7 100644 --- a/src/drivers/AbstractDriver.ts +++ b/src/drivers/AbstractDriver.ts @@ -69,8 +69,7 @@ export default abstract class AbstractDriver { public abstract GetAllTablesQuery: ( schema: string, - dbNames: string, - tableNames: string[] + dbNames: string ) => Promise< { TABLE_SCHEMA: string; @@ -188,8 +187,7 @@ export default abstract class AbstractDriver { ); dbModel = await this.GetAllTables( sqlEscapedSchema, - connectionOptions.databaseName, - connectionOptions.skipTables + connectionOptions.databaseName ); await this.GetCoulmnsFromEntity( dbModel, @@ -210,21 +208,35 @@ export default abstract class AbstractDriver { ); await this.DisconnectFromServer(); dbModel = AbstractDriver.FindManyToManyRelations(dbModel); + dbModel = AbstractDriver.FilterGeneratedTables( + dbModel, + connectionOptions.skipTables, + connectionOptions.onlyTables + ); return dbModel; } + static FilterGeneratedTables( + dbModel: Entity[], + skipTables: string[], + onlyTables: string[] + ): Entity[] { + return dbModel + .filter((table) => !skipTables.includes(table.sqlName)) + .filter( + (table) => + onlyTables.length === 0 || + onlyTables.includes(table.sqlName) + ); + } + public abstract async ConnectToServer(connectionOptons: IConnectionOptions); public async GetAllTables( schema: string, - dbNames: string, - tableNames: string[] + dbNames: string ): Promise { - const response = await this.GetAllTablesQuery( - schema, - dbNames, - tableNames - ); + const response = await this.GetAllTablesQuery(schema, dbNames); const ret: Entity[] = [] as Entity[]; response.forEach((val) => { ret.push({ diff --git a/src/drivers/MssqlDriver.ts b/src/drivers/MssqlDriver.ts index ccc846f..01dcb66 100644 --- a/src/drivers/MssqlDriver.ts +++ b/src/drivers/MssqlDriver.ts @@ -37,16 +37,8 @@ export default class MssqlDriver extends AbstractDriver { } } - public GetAllTablesQuery = async ( - schema: string, - dbNames: string, - tableNames: string[] - ) => { + public GetAllTablesQuery = async (schema: string, dbNames: string) => { const request = new this.MSSQL.Request(this.Connection); - const tableCondition = - tableNames.length > 0 - ? ` AND NOT TABLE_NAME IN ('${tableNames.join("','")}')` - : ""; const response: { TABLE_SCHEMA: string; TABLE_NAME: string; @@ -56,7 +48,7 @@ export default class MssqlDriver extends AbstractDriver { `SELECT TABLE_SCHEMA,TABLE_NAME, table_catalog as "DB_NAME" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' and TABLE_SCHEMA in (${schema}) AND TABLE_CATALOG in (${MssqlDriver.escapeCommaSeparatedList( dbNames - )}) ${tableCondition}` + )})` ) ).recordset; return response; diff --git a/src/drivers/MysqlDriver.ts b/src/drivers/MysqlDriver.ts index 95d7f17..921c310 100644 --- a/src/drivers/MysqlDriver.ts +++ b/src/drivers/MysqlDriver.ts @@ -39,15 +39,7 @@ export default class MysqlDriver extends AbstractDriver { } } - public GetAllTablesQuery = async ( - schema: string, - dbNames: string, - tableNames: string[] - ) => { - const tableCondition = - tableNames.length > 0 - ? ` AND NOT TABLE_NAME IN ('${tableNames.join("','")}')` - : ""; + public GetAllTablesQuery = async (schema: string, dbNames: string) => { const response = this.ExecQuery<{ TABLE_SCHEMA: string; TABLE_NAME: string; @@ -57,7 +49,7 @@ export default class MysqlDriver extends AbstractDriver { WHERE table_type='BASE TABLE' AND table_schema IN (${MysqlDriver.escapeCommaSeparatedList( dbNames - )}) ${tableCondition}`); + )})`); return response; }; diff --git a/src/drivers/OracleDriver.ts b/src/drivers/OracleDriver.ts index 2282887..9636349 100644 --- a/src/drivers/OracleDriver.ts +++ b/src/drivers/OracleDriver.ts @@ -38,22 +38,14 @@ export default class OracleDriver extends AbstractDriver { } } - public GetAllTablesQuery = async ( - schema: string, - dbNames: string, - tableNames: string[] - ) => { - const tableCondition = - tableNames.length > 0 - ? ` AND NOT TABLE_NAME IN ('${tableNames.join("','")}')` - : ""; + public GetAllTablesQuery = async (schema: string, dbNames: string) => { const response = ( await this.Connection.execute<{ TABLE_SCHEMA: string; TABLE_NAME: string; DB_NAME: string; }>( - `SELECT NULL AS TABLE_SCHEMA, TABLE_NAME, NULL AS DB_NAME FROM all_tables WHERE owner = (select user from dual) ${tableCondition}` + `SELECT NULL AS TABLE_SCHEMA, TABLE_NAME, NULL AS DB_NAME FROM all_tables WHERE owner = (select user from dual)` ) ).rows!; return response; diff --git a/src/drivers/PostgresDriver.ts b/src/drivers/PostgresDriver.ts index c3f4f86..4f4b285 100644 --- a/src/drivers/PostgresDriver.ts +++ b/src/drivers/PostgresDriver.ts @@ -37,22 +37,14 @@ export default class PostgresDriver extends AbstractDriver { } } - public GetAllTablesQuery = async ( - schema: string, - dbNames: string, - tableNames: string[] - ) => { - const tableCondition = - tableNames.length > 0 - ? ` AND NOT table_name IN ('${tableNames.join("','")}')` - : ""; + public GetAllTablesQuery = async (schema: string, dbNames: string) => { const response: { TABLE_SCHEMA: string; TABLE_NAME: string; DB_NAME: string; }[] = ( await this.Connection.query( - `SELECT table_schema as "TABLE_SCHEMA",table_name as "TABLE_NAME", table_catalog as "DB_NAME" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND table_schema in (${schema}) ${tableCondition}` + `SELECT table_schema as "TABLE_SCHEMA",table_name as "TABLE_NAME", table_catalog as "DB_NAME" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND table_schema in (${schema})` ) ).rows; return response; diff --git a/src/drivers/SqliteDriver.ts b/src/drivers/SqliteDriver.ts index 1883a97..abe1fa2 100644 --- a/src/drivers/SqliteDriver.ts +++ b/src/drivers/SqliteDriver.ts @@ -46,17 +46,12 @@ export default class SqliteDriver extends AbstractDriver { public async GetAllTables( schema: string, - dbNames: string, - tableNames: string[] + dbNames: string ): Promise { const ret: Entity[] = [] as Entity[]; - const tableCondition = - tableNames.length > 0 - ? ` AND NOT tbl_name IN ('${tableNames.join("','")}')` - : ""; // eslint-disable-next-line camelcase const rows = await this.ExecQuery<{ tbl_name: string; sql: string }>( - `SELECT tbl_name, sql FROM "sqlite_master" WHERE "type" = 'table' AND name NOT LIKE 'sqlite_%' ${tableCondition}` + `SELECT tbl_name, sql FROM "sqlite_master" WHERE "type" = 'table' AND name NOT LIKE 'sqlite_%'` ); rows.forEach((val) => { if (val.sql.includes("AUTOINCREMENT")) { diff --git a/src/index.ts b/src/index.ts index 8aab4a3..122cbcd 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,8 +7,8 @@ import IConnectionOptions, { import IGenerationOptions, { getDefaultGenerationOptions, } from "./IGenerationOptions"; - import fs = require("fs-extra"); + import inquirer = require("inquirer"); import path = require("path"); @@ -270,6 +270,12 @@ function checkYargsParameters(options: options): options { describe: "Skip schema generation for specific tables. You can pass multiple values separated by comma", }, + tables: { + string: true, + default: options.connectionOptions.onlyTables.join(","), + describe: + "Generate specific tables. You can pass multiple values separated by comma", + }, strictMode: { choices: ["none", "?", "!"], default: options.generationOptions.strictMode, @@ -305,7 +311,12 @@ function checkYargsParameters(options: options): options { if (skipTables.length === 1 && skipTables[0] === "") { skipTables = []; // #252 } + let tables = argv.tables.split(","); + if (tables.length === 1 && tables[0] === "") { + tables = []; + } options.connectionOptions.skipTables = skipTables; + options.connectionOptions.onlyTables = tables; options.generationOptions.activeRecord = argv.a; options.generationOptions.generateConstructor = argv.generateConstructor; options.generationOptions.convertCaseEntity = argv.ce as IGenerationOptions["convertCaseEntity"]; @@ -441,23 +452,43 @@ async function useInquirer(options: options): Promise { ? "All of them" : "Ignore specific tables", message: "Generate schema for tables:", - choices: ["All of them", "Ignore specific tables"], + choices: [ + "All of them", + "Ignore specific tables", + "Select specific tables", + ], name: "specyficTables", type: "list", }, ]) ).specyficTables; - if (ignoreSpecyficTables === "Ignore specific tables") { - const { tableNames } = await inquirer.prompt({ - default: options.connectionOptions.skipTables.join(","), - message: "Table names(separated by comma)", - name: "tableNames", - type: "input", - }); - options.connectionOptions.skipTables = tableNames.split(","); - } else { - options.connectionOptions.skipTables = []; - } + + const optionsMapper = { + "All of them": () => { + options.connectionOptions.skipTables = []; + options.connectionOptions.onlyTables = []; + }, + "Ignore specific tables": async () => { + const { tableNames } = await inquirer.prompt({ + default: options.connectionOptions.skipTables.join(","), + message: "Table names(separated by comma)", + name: "tableNames", + type: "input", + }); + options.connectionOptions.skipTables = tableNames.split(","); + }, + "Select specific tables": async () => { + const { tableNames } = await inquirer.prompt({ + default: options.connectionOptions.onlyTables.join(","), + message: "Table names(separated by comma)", + name: "tableNames", + type: "input", + }); + options.connectionOptions.onlyTables = tableNames.split(","); + }, + }; + + await optionsMapper[ignoreSpecyficTables](); options.generationOptions.resultsPath = ( await inquirer.prompt([ diff --git a/test/integration/runTestsFromPath.test.ts b/test/integration/runTestsFromPath.test.ts index f645fa0..0eb7af6 100644 --- a/test/integration/runTestsFromPath.test.ts +++ b/test/integration/runTestsFromPath.test.ts @@ -38,22 +38,57 @@ describe("TypeOrm examples", async () => { const testPartialPath = "test/integration/examples"; await runTestsFromPath(testPartialPath, false); }); +describe("Filtering tables", async () => { + const testPartialPath = "test/integration/examples"; + it("skipTables",async ()=>{ + const dbDrivers: string[] = GTU.getEnabledDbDrivers(); + const modelGenerationPromises = dbDrivers.map(async dbDriver => { + const { + generationOptions, + driver, + connectionOptions + } = await prepareTestRuns(testPartialPath, "sample2-one-to-one", dbDriver); + const dbModel = await dataCollectionPhase( + driver, + {...connectionOptions, + skipTables:["Post"]}, + generationOptions + ); + expect(dbModel.length).to.equal(6); + // eslint-disable-next-line no-unused-expressions + expect(dbModel.find(x=>x.sqlName==="Post")).to.be.undefined; + }); + await Promise.all(modelGenerationPromises); + }); + it("onlyTables",async ()=>{ + const dbDrivers: string[] = GTU.getEnabledDbDrivers(); + const modelGenerationPromises = dbDrivers.map(async dbDriver => { + const { + generationOptions, + driver, + connectionOptions + } = await prepareTestRuns(testPartialPath, "sample2-one-to-one", dbDriver); + const dbModel = await dataCollectionPhase( + driver, + {...connectionOptions, + onlyTables:["Post"]}, + generationOptions + ); + expect(dbModel.length).to.equal(1); + // eslint-disable-next-line no-unused-expressions + expect(dbModel.find(x=>x.sqlName==="Post")).to.not.be.undefined; + }); + await Promise.all(modelGenerationPromises); + }) + +}); async function runTestsFromPath( testPartialPath: string, isDbSpecific: boolean ) { - const resultsPath = path.resolve(process.cwd(), `output`); - if (!fs.existsSync(resultsPath)) { - fs.mkdirSync(resultsPath); - } const dbDrivers: string[] = GTU.getEnabledDbDrivers(); - dbDrivers.forEach(dbDriver => { - const newDirPath = path.resolve(resultsPath, dbDriver); - if (!fs.existsSync(newDirPath)) { - fs.mkdirSync(newDirPath); - } - }); + createOutputDirs(dbDrivers); const files = fs.readdirSync(path.resolve(process.cwd(), testPartialPath)); if (isDbSpecific) { await runTest(dbDrivers, testPartialPath, files); @@ -64,6 +99,19 @@ async function runTestsFromPath( } } +function createOutputDirs(dbDrivers: string[]) { + const resultsPath = path.resolve(process.cwd(), `output`); + if (!fs.existsSync(resultsPath)) { + fs.mkdirSync(resultsPath); + } + dbDrivers.forEach(dbDriver => { + const newDirPath = path.resolve(resultsPath, dbDriver); + if (!fs.existsSync(newDirPath)) { + fs.mkdirSync(newDirPath); + } + }); +} + function runTestForMultipleDrivers( testName: string, dbDrivers: string[], @@ -350,7 +398,8 @@ async function prepareTestRuns( databaseType: "mysql", schemaName: "ignored", ssl: yn(process.env.MYSQL_SSL, { default: false }), - skipTables: [] + skipTables: [], + onlyTables: [], }; break; case "mariadb": @@ -363,7 +412,8 @@ async function prepareTestRuns( databaseType: "mariadb", schemaName: "ignored", ssl: yn(process.env.MARIADB_SSL, { default: false }), - skipTables: [] + skipTables: [], + onlyTables: [] }; break; diff --git a/test/utils/GeneralTestUtils.ts b/test/utils/GeneralTestUtils.ts index 813ebd1..763225d 100644 --- a/test/utils/GeneralTestUtils.ts +++ b/test/utils/GeneralTestUtils.ts @@ -31,7 +31,8 @@ export async function createMSSQLModels( databaseType: "mssql", schemaName: "dbo,sch1,sch2", ssl: yn(process.env.MSSQL_SSL, { default: false }), - skipTables: [] + skipTables: [], + onlyTables: [] }; await driver.ConnectToServer(connectionOptions); connectionOptions.databaseName = String(process.env.MSSQL_Database); @@ -83,7 +84,8 @@ export async function createPostgresModels( databaseType: "postgres", schemaName: "public,sch1,sch2", ssl: yn(process.env.POSTGRES_SSL, { default: false }), - skipTables: ["spatial_ref_sys"] + skipTables: ["spatial_ref_sys"], + onlyTables: [] }; await driver.ConnectToServer(connectionOptions); connectionOptions.databaseName = String(process.env.POSTGRES_Database); @@ -134,7 +136,8 @@ export async function createSQLiteModels( databaseType: "sqlite", schemaName: "", ssl: false, - skipTables: [] + skipTables: [], + onlyTables: [] }; const connOpt: ConnectionOptions = { @@ -169,7 +172,8 @@ export async function createMysqlModels( databaseType: "mysql", schemaName: "ignored", ssl: yn(process.env.MYSQL_SSL, { default: false }), - skipTables: [] + skipTables: [], + onlyTables: [] }; await driver.ConnectToServer(connectionOptions); @@ -212,7 +216,8 @@ export async function createMariaDBModels( databaseType: "mariadb", schemaName: "ignored", ssl: yn(process.env.MARIADB_SSL, { default: false }), - skipTables: [] + skipTables: [], + onlyTables: [] }; await driver.ConnectToServer(connectionOptions); @@ -257,7 +262,8 @@ export async function createOracleDBModels( databaseType: "oracle", schemaName: String(process.env.ORACLE_Username), ssl: yn(process.env.ORACLE_SSL, { default: false }), - skipTables: [] + skipTables: [], + onlyTables: [] }; await driver.ConnectToServer(connectionOptions); connectionOptions.user = String(process.env.ORACLE_Username);